From dabb3a0962a5791ee00cfb11dc9016fde86cdc1a Mon Sep 17 00:00:00 2001 From: Jinesh Jain <132554375+Jineshbansal@users.noreply.github.com> Date: Wed, 20 Aug 2025 06:50:37 +0530 Subject: [PATCH] FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770) * FIXME: Make TProtocol Optional[TProtocol] to keep types consistent * correct test case of test_protocol_muxer * add newsfragment * unit test added --------- Co-authored-by: Manu Sheel Gupta --- libp2p/host/basic_host.py | 7 ++ libp2p/protocol_muxer/multiselect.py | 11 +- libp2p/protocol_muxer/multiselect_client.py | 6 +- .../multiselect_communicator.py | 5 +- libp2p/security/security_multistream.py | 7 +- libp2p/stream_muxer/muxer_multistream.py | 7 +- newsfragments/770.internal.rst | 1 + .../protocol_muxer/test_negotiate_timeout.py | 102 ++++++++++++++++-- .../protocol_muxer/test_protocol_muxer.py | 38 +++++++ 9 files changed, 167 insertions(+), 17 deletions(-) create mode 100644 newsfragments/770.internal.rst diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 70e41953..b40b0128 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -295,6 +295,13 @@ class BasicHost(IHost): ) await net_stream.reset() return + if protocol is None: + logger.debug( + "no protocol negotiated, closing stream from peer %s", + net_stream.muxed_conn.peer_id, + ) + await net_stream.reset() + return net_stream.set_protocol(protocol) if handler is None: logger.debug( diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8d311391..287a01f3 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer): """ self.handlers[protocol] = handler - # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( self, communicator: IMultiselectCommunicator, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, - ) -> tuple[TProtocol, StreamHandlerFn | None]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ Negotiate performs protocol selection. @@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer): raise MultiselectError() from error else: - protocol = TProtocol(command) - if protocol in self.handlers: + protocol_to_check = None if not command else TProtocol(command) + if protocol_to_check in self.handlers: try: - await communicator.write(protocol) + await communicator.write(command) except MultiselectCommunicatorError as error: raise MultiselectError() from error - return protocol, self.handlers[protocol] + return protocol_to_check, self.handlers[protocol_to_check] try: await communicator.write(PROTOCOL_NOT_FOUND_MSG) except MultiselectCommunicatorError as error: diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index a5b35006..90adb251 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient): :raise MultiselectClientError: raised when protocol negotiation failed :return: selected protocol """ + # Represent `None` protocol as an empty string. + protocol_str = protocol if protocol is not None else "" try: - await communicator.write(protocol) + await communicator.write(protocol_str) except MultiselectCommunicatorError as error: raise MultiselectClientError() from error @@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient): except MultiselectCommunicatorError as error: raise MultiselectClientError() from error - if response == protocol: + if response == protocol_str: return protocol if response == PROTOCOL_NOT_FOUND_MSG: raise MultiselectClientError("protocol not supported") diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index c52266fd..98a8129c 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator): """ :raise MultiselectCommunicatorError: raised when failed to write to underlying reader """ # noqa: E501 - msg_bytes = encode_delim(msg_str.encode()) + if msg_str is None: + msg_bytes = encode_delim(b"") + else: + msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) except IOException as error: diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 193cc092..a9c4b19c 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -17,6 +17,9 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -104,7 +107,7 @@ class SecurityMultistream(ABC): :param is_initiator: true if we are the initiator, false otherwise :return: selected secure transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if is_initiator: # Select protocol if initiator @@ -114,5 +117,7 @@ class SecurityMultistream(ABC): else: # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a security protocol") # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76699c67..322db912 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -17,6 +17,9 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -73,7 +76,7 @@ class MuxerMultistream: :param conn: conn to choose a transport over :return: selected muxer transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( @@ -81,6 +84,8 @@ class MuxerMultistream: ) else: protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a stream muxer protocol") return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/newsfragments/770.internal.rst b/newsfragments/770.internal.rst new file mode 100644 index 00000000..f33cb3c0 --- /dev/null +++ b/newsfragments/770.internal.rst @@ -0,0 +1 @@ +Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py diff --git a/tests/core/protocol_muxer/test_negotiate_timeout.py b/tests/core/protocol_muxer/test_negotiate_timeout.py index a50d65f6..1d089949 100644 --- a/tests/core/protocol_muxer/test_negotiate_timeout.py +++ b/tests/core/protocol_muxer/test_negotiate_timeout.py @@ -1,9 +1,9 @@ +from collections import deque + import pytest import trio -from libp2p.abc import ( - IMultiselectCommunicator, -) +from libp2p.abc import IMultiselectCommunicator, INetStream from libp2p.custom_types import TProtocol from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, @@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient +async def dummy_handler(stream: INetStream) -> None: + pass + + class DummyMultiselectCommunicator(IMultiselectCommunicator): """ Dummy MultiSelectCommunicator to test out negotiate timmeout. @@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator): @pytest.mark.trio -async def test_select_one_of_timeout(): +async def test_select_one_of_timeout() -> None: ECHO = TProtocol("/echo/1.0.0") communicator = DummyMultiselectCommunicator() @@ -42,7 +46,7 @@ async def test_select_one_of_timeout(): @pytest.mark.trio -async def test_query_multistream_command_timeout(): +async def test_query_multistream_command_timeout() -> None: communicator = DummyMultiselectCommunicator() client = MultiselectClient() @@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout(): @pytest.mark.trio -async def test_negotiate_timeout(): +async def test_negotiate_timeout() -> None: communicator = DummyMultiselectCommunicator() server = Multiselect() with pytest.raises(MultiselectError, match="handshake read timeout"): await server.negotiate(communicator, 2) + + +class HandshakeThenHangCommunicator(IMultiselectCommunicator): + handshaked: bool + + def __init__(self) -> None: + self.handshaked = False + + async def write(self, msg_str: str) -> None: + if msg_str == "/multistream/1.0.0": + self.handshaked = True + return + + async def read(self) -> str: + if not self.handshaked: + return "/multistream/1.0.0" + # After handshake, hang on read. + await trio.sleep_forever() + # Should not be reached. + return "" + + +@pytest.mark.trio +async def test_negotiate_timeout_post_handshake() -> None: + communicator = HandshakeThenHangCommunicator() + server = Multiselect() + with pytest.raises(MultiselectError, match="handshake read timeout"): + await server.negotiate(communicator, 1) + + +class MockCommunicator(IMultiselectCommunicator): + def __init__(self, commands_to_read: list[str]): + self.read_queue = deque(commands_to_read) + self.written_data: list[str] = [] + + async def write(self, msg_str: str) -> None: + self.written_data.append(msg_str) + + async def read(self) -> str: + if not self.read_queue: + raise EOFError + return self.read_queue.popleft() + + +@pytest.mark.trio +async def test_negotiate_empty_string_command() -> None: + # server receives an empty string, which means client wants `None` protocol. + server = Multiselect({None: dummy_handler}) + # Handshake, then empty command + communicator = MockCommunicator(["/multistream/1.0.0", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check that server sent back handshake and the protocol confirmation (empty string) + assert communicator.written_data == ["/multistream/1.0.0", ""] + + +@pytest.mark.trio +async def test_negotiate_with_none_handler() -> None: + # server has None handler, client sends "" to select it. + server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler}) + # Handshake, then empty command + communicator = MockCommunicator(["/multistream/1.0.0", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check written data: handshake, protocol confirmation + assert communicator.written_data == ["/multistream/1.0.0", ""] + + +@pytest.mark.trio +async def test_negotiate_with_none_handler_ls() -> None: + # server has None handler, client sends "ls" then empty string. + server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler}) + # Handshake, ls, empty command + communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check written data: handshake, ls response, protocol confirmation + assert communicator.written_data[0] == "/multistream/1.0.0" + assert "/proto1" in communicator.written_data[1] + # Note: `ls` should not list the `None` protocol. + assert "None" not in communicator.written_data[1] + assert "\n\n" not in communicator.written_data[1] + assert communicator.written_data[2] == "" diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index 1d6a0f86..57939bb6 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols(): protocols = ms.get_protocols() assert set(protocols) == {p1, p2, p3} + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol(security_protocol): + with pytest.raises(Exception): + await perform_simple_test( + None, + [None], + [None], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_server_no_none( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol): + expected_selected_protocol = PROTOCOL_ECHO + await perform_simple_test( + expected_selected_protocol, + [None, PROTOCOL_ECHO], + [PROTOCOL_ECHO], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_server_none_client_other( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)