diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index e370a3de..6b7eb1d3 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -213,7 +213,6 @@ class BasicHost(IHost): self, peer_id: ID, protocol_ids: Sequence[TProtocol], - negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> INetStream: """ :param peer_id: peer_id that host is connecting @@ -227,7 +226,7 @@ class BasicHost(IHost): selected_protocol = await self.multiselect_client.select_one_of( list(protocol_ids), MultiselectCommunicator(net_stream), - negotitate_timeout, + self.negotiate_timeout, ) except MultiselectClientError as error: logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index ef90fac0..2d206141 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import ( MultiselectError, ) from libp2p.protocol_muxer.multiselect import ( + DEFAULT_NEGOTIATE_TIMEOUT, Multiselect, ) from libp2p.protocol_muxer.multiselect_client import ( @@ -46,11 +47,17 @@ class MuxerMultistream: transports: "OrderedDict[TProtocol, TMuxerClass]" multiselect: Multiselect multiselect_client: MultiselectClient + negotiate_timeout: int - def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None: + def __init__( + self, + muxer_transports_by_protocol: TMuxerOptions, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + ) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multistream_client = MultiselectClient() + self.negotiate_timeout = negotiate_timeout for protocol, transport in muxer_transports_by_protocol.items(): self.add_transport(protocol, transport) @@ -80,10 +87,12 @@ class MuxerMultistream: communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( - tuple(self.transports.keys()), communicator + tuple(self.transports.keys()), communicator, self.negotiate_timeout ) else: - protocol, _ = await self.multiselect.negotiate(communicator) + protocol, _ = await self.multiselect.negotiate( + communicator, self.negotiate_timeout + ) if protocol is None: raise MultiselectError( "Fail to negotiate a stream muxer protocol: no protocol selected" @@ -93,7 +102,7 @@ class MuxerMultistream: async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: communicator = MultiselectCommunicator(conn) protocol = await self.multistream_client.select_one_of( - tuple(self.transports.keys()), communicator + tuple(self.transports.keys()), communicator, self.negotiate_timeout ) transport_class = self.transports[protocol] if protocol == PROTOCOL_ID: diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 40ba5321..dad2ad72 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, ) +from libp2p.protocol_muxer.multiselect import ( + DEFAULT_NEGOTIATE_TIMEOUT, +) from libp2p.security.exceptions import ( HandshakeFailure, ) @@ -37,9 +40,12 @@ class TransportUpgrader: self, secure_transports_by_protocol: TSecurityOptions, muxer_transports_by_protocol: TMuxerOptions, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ): self.security_multistream = SecurityMultistream(secure_transports_by_protocol) - self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) + self.muxer_multistream = MuxerMultistream( + muxer_transports_by_protocol, negotiate_timeout + ) async def upgrade_security( self, diff --git a/newsfragments/896.bugfix.rst b/newsfragments/896.bugfix.rst new file mode 100644 index 00000000..aaf338d4 --- /dev/null +++ b/newsfragments/896.bugfix.rst @@ -0,0 +1 @@ +Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly diff --git a/tests/core/stream_muxer/test_muxer_multistream.py b/tests/core/stream_muxer/test_muxer_multistream.py new file mode 100644 index 00000000..070d47ae --- /dev/null +++ b/tests/core/stream_muxer/test_muxer_multistream.py @@ -0,0 +1,108 @@ +from unittest.mock import ( + AsyncMock, + MagicMock, +) + +import pytest + +from libp2p.custom_types import ( + TMuxerClass, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) +from libp2p.stream_muxer.muxer_multistream import ( + MuxerMultistream, +) + + +@pytest.mark.trio +async def test_muxer_timeout_configuration(): + """Test that muxer respects timeout configuration.""" + muxer = MuxerMultistream({}, negotiate_timeout=1) + assert muxer.negotiate_timeout == 1 + + +@pytest.mark.trio +async def test_select_transport_passes_timeout_to_multiselect(): + """Test that timeout is passed to multiselect client in select_transport.""" + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = False + + # Mock MultiselectClient + muxer = MuxerMultistream({}, negotiate_timeout=10) + muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None)) + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) + + # Call select_transport + await muxer.select_transport(mock_conn) + + # Verify that select_one_of was called with the correct timeout + args, _ = muxer.multiselect.negotiate.call_args + assert args[1] == 10 + + +@pytest.mark.trio +async def test_new_conn_passes_timeout_to_multistream_client(): + """Test that timeout is passed to multistream client in new_conn.""" + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = True + mock_peer_id = ID(b"test_peer") + mock_communicator = MagicMock() + + # Mock MultistreamClient and transports + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol") + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) + + # Call new_conn + await muxer.new_conn(mock_conn, mock_peer_id) + + # Verify that select_one_of was called with the correct timeout + muxer.multistream_client.select_one_of( + tuple(muxer.transports.keys()), mock_communicator, 30 + ) + + +@pytest.mark.trio +async def test_select_transport_no_protocol_selected(): + """ + Test that select_transport raises MultiselectError when no protocol is selected. + """ + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = False + + # Mock Multiselect to return None + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.multiselect.negotiate = AsyncMock(return_value=(None, None)) + + # Expect MultiselectError to be raised + with pytest.raises(MultiselectError, match="no protocol selected"): + await muxer.select_transport(mock_conn) + + +@pytest.mark.trio +async def test_add_transport_updates_precedence(): + """Test that adding a transport updates protocol precedence.""" + # Mock transport classes + mock_transport1 = MagicMock(spec=TMuxerClass) + mock_transport2 = MagicMock(spec=TMuxerClass) + + # Initialize muxer and add transports + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.add_transport(TProtocol("proto1"), mock_transport1) + muxer.add_transport(TProtocol("proto2"), mock_transport2) + + # Verify transport order + assert list(muxer.transports.keys()) == ["proto1", "proto2"] + + # Re-add proto1 to check if it moves to the end + muxer.add_transport(TProtocol("proto1"), mock_transport1) + assert list(muxer.transports.keys()) == ["proto2", "proto1"] diff --git a/tests/core/transport/test_upgrader.py b/tests/core/transport/test_upgrader.py new file mode 100644 index 00000000..8535a039 --- /dev/null +++ b/tests/core/transport/test_upgrader.py @@ -0,0 +1,27 @@ +import pytest + +from libp2p.custom_types import ( + TMuxerOptions, + TSecurityOptions, +) +from libp2p.transport.upgrader import ( + TransportUpgrader, +) + + +@pytest.mark.trio +async def test_transport_upgrader_security_and_muxer_initialization(): + """Test TransportUpgrader initializes security and muxer multistreams correctly.""" + secure_transports: TSecurityOptions = {} + muxer_transports: TMuxerOptions = {} + negotiate_timeout = 15 + + upgrader = TransportUpgrader( + secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout + ) + + # Verify security multistream initialization + assert upgrader.security_multistream.transports == secure_transports + # Verify muxer multistream initialization and timeout + assert upgrader.muxer_multistream.transports == muxer_transports + assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout