Merge branch 'main' into add-ws-transport

This commit is contained in:
Manu Sheel Gupta
2025-09-16 01:04:16 +05:30
committed by GitHub
6 changed files with 157 additions and 7 deletions

View File

@ -213,7 +213,6 @@ class BasicHost(IHost):
self,
peer_id: ID,
protocol_ids: Sequence[TProtocol],
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> INetStream:
"""
:param peer_id: peer_id that host is connecting
@ -227,7 +226,7 @@ class BasicHost(IHost):
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids),
MultiselectCommunicator(net_stream),
negotitate_timeout,
self.negotiate_timeout,
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)

View File

@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
Multiselect,
)
from libp2p.protocol_muxer.multiselect_client import (
@ -46,11 +47,17 @@ class MuxerMultistream:
transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect
multiselect_client: MultiselectClient
negotiate_timeout: int
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
def __init__(
self,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multistream_client = MultiselectClient()
self.negotiate_timeout = negotiate_timeout
for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
@ -80,10 +87,12 @@ class MuxerMultistream:
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
protocol, _ = await self.multiselect.negotiate(
communicator, self.negotiate_timeout
)
if protocol is None:
raise MultiselectError(
"Fail to negotiate a stream muxer protocol: no protocol selected"
@ -93,7 +102,7 @@ class MuxerMultistream:
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
communicator = MultiselectCommunicator(conn)
protocol = await self.multistream_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
transport_class = self.transports[protocol]
if protocol == PROTOCOL_ID:

View File

@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
)
from libp2p.security.exceptions import (
HandshakeFailure,
)
@ -37,9 +40,12 @@ class TransportUpgrader:
self,
secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(
muxer_transports_by_protocol, negotiate_timeout
)
async def upgrade_security(
self,

View File

@ -0,0 +1 @@
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly

View File

@ -0,0 +1,108 @@
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
from libp2p.custom_types import (
TMuxerClass,
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.stream_muxer.muxer_multistream import (
MuxerMultistream,
)
@pytest.mark.trio
async def test_muxer_timeout_configuration():
"""Test that muxer respects timeout configuration."""
muxer = MuxerMultistream({}, negotiate_timeout=1)
assert muxer.negotiate_timeout == 1
@pytest.mark.trio
async def test_select_transport_passes_timeout_to_multiselect():
"""Test that timeout is passed to multiselect client in select_transport."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock MultiselectClient
muxer = MuxerMultistream({}, negotiate_timeout=10)
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call select_transport
await muxer.select_transport(mock_conn)
# Verify that select_one_of was called with the correct timeout
args, _ = muxer.multiselect.negotiate.call_args
assert args[1] == 10
@pytest.mark.trio
async def test_new_conn_passes_timeout_to_multistream_client():
"""Test that timeout is passed to multistream client in new_conn."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = True
mock_peer_id = ID(b"test_peer")
mock_communicator = MagicMock()
# Mock MultistreamClient and transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call new_conn
await muxer.new_conn(mock_conn, mock_peer_id)
# Verify that select_one_of was called with the correct timeout
muxer.multistream_client.select_one_of(
tuple(muxer.transports.keys()), mock_communicator, 30
)
@pytest.mark.trio
async def test_select_transport_no_protocol_selected():
"""
Test that select_transport raises MultiselectError when no protocol is selected.
"""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock Multiselect to return None
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
# Expect MultiselectError to be raised
with pytest.raises(MultiselectError, match="no protocol selected"):
await muxer.select_transport(mock_conn)
@pytest.mark.trio
async def test_add_transport_updates_precedence():
"""Test that adding a transport updates protocol precedence."""
# Mock transport classes
mock_transport1 = MagicMock(spec=TMuxerClass)
mock_transport2 = MagicMock(spec=TMuxerClass)
# Initialize muxer and add transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.add_transport(TProtocol("proto1"), mock_transport1)
muxer.add_transport(TProtocol("proto2"), mock_transport2)
# Verify transport order
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
# Re-add proto1 to check if it moves to the end
muxer.add_transport(TProtocol("proto1"), mock_transport1)
assert list(muxer.transports.keys()) == ["proto2", "proto1"]

View File

@ -0,0 +1,27 @@
import pytest
from libp2p.custom_types import (
TMuxerOptions,
TSecurityOptions,
)
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@pytest.mark.trio
async def test_transport_upgrader_security_and_muxer_initialization():
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
secure_transports: TSecurityOptions = {}
muxer_transports: TMuxerOptions = {}
negotiate_timeout = 15
upgrader = TransportUpgrader(
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
)
# Verify security multistream initialization
assert upgrader.security_multistream.transports == secure_transports
# Verify muxer multistream initialization and timeout
assert upgrader.muxer_multistream.transports == muxer_transports
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout