mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Added timeout passing in muxermultistream. Updated the usages. Tested the params are passed correctly
This commit is contained in:
@ -213,7 +213,6 @@ class BasicHost(IHost):
|
|||||||
self,
|
self,
|
||||||
peer_id: ID,
|
peer_id: ID,
|
||||||
protocol_ids: Sequence[TProtocol],
|
protocol_ids: Sequence[TProtocol],
|
||||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
|
||||||
) -> INetStream:
|
) -> INetStream:
|
||||||
"""
|
"""
|
||||||
:param peer_id: peer_id that host is connecting
|
:param peer_id: peer_id that host is connecting
|
||||||
@ -227,7 +226,7 @@ class BasicHost(IHost):
|
|||||||
selected_protocol = await self.multiselect_client.select_one_of(
|
selected_protocol = await self.multiselect_client.select_one_of(
|
||||||
list(protocol_ids),
|
list(protocol_ids),
|
||||||
MultiselectCommunicator(net_stream),
|
MultiselectCommunicator(net_stream),
|
||||||
negotitate_timeout,
|
self.negotiate_timeout,
|
||||||
)
|
)
|
||||||
except MultiselectClientError as error:
|
except MultiselectClientError as error:
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
|
|||||||
MultiselectError,
|
MultiselectError,
|
||||||
)
|
)
|
||||||
from libp2p.protocol_muxer.multiselect import (
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
Multiselect,
|
Multiselect,
|
||||||
)
|
)
|
||||||
from libp2p.protocol_muxer.multiselect_client import (
|
from libp2p.protocol_muxer.multiselect_client import (
|
||||||
@ -46,11 +47,17 @@ class MuxerMultistream:
|
|||||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||||
multiselect: Multiselect
|
multiselect: Multiselect
|
||||||
multiselect_client: MultiselectClient
|
multiselect_client: MultiselectClient
|
||||||
|
negotiate_timeout: int
|
||||||
|
|
||||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
muxer_transports_by_protocol: TMuxerOptions,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
) -> None:
|
||||||
self.transports = OrderedDict()
|
self.transports = OrderedDict()
|
||||||
self.multiselect = Multiselect()
|
self.multiselect = Multiselect()
|
||||||
self.multistream_client = MultiselectClient()
|
self.multistream_client = MultiselectClient()
|
||||||
|
self.negotiate_timeout = negotiate_timeout
|
||||||
for protocol, transport in muxer_transports_by_protocol.items():
|
for protocol, transport in muxer_transports_by_protocol.items():
|
||||||
self.add_transport(protocol, transport)
|
self.add_transport(protocol, transport)
|
||||||
|
|
||||||
@ -80,10 +87,12 @@ class MuxerMultistream:
|
|||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if conn.is_initiator:
|
if conn.is_initiator:
|
||||||
protocol = await self.multiselect_client.select_one_of(
|
protocol = await self.multiselect_client.select_one_of(
|
||||||
tuple(self.transports.keys()), communicator
|
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
protocol, _ = await self.multiselect.negotiate(
|
||||||
|
communicator, self.negotiate_timeout
|
||||||
|
)
|
||||||
if protocol is None:
|
if protocol is None:
|
||||||
raise MultiselectError(
|
raise MultiselectError(
|
||||||
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
||||||
@ -93,7 +102,7 @@ class MuxerMultistream:
|
|||||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
protocol = await self.multistream_client.select_one_of(
|
protocol = await self.multistream_client.select_one_of(
|
||||||
tuple(self.transports.keys()), communicator
|
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||||
)
|
)
|
||||||
transport_class = self.transports[protocol]
|
transport_class = self.transports[protocol]
|
||||||
if protocol == PROTOCOL_ID:
|
if protocol == PROTOCOL_ID:
|
||||||
|
|||||||
@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
|
|||||||
MultiselectClientError,
|
MultiselectClientError,
|
||||||
MultiselectError,
|
MultiselectError,
|
||||||
)
|
)
|
||||||
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
)
|
||||||
from libp2p.security.exceptions import (
|
from libp2p.security.exceptions import (
|
||||||
HandshakeFailure,
|
HandshakeFailure,
|
||||||
)
|
)
|
||||||
@ -37,9 +40,12 @@ class TransportUpgrader:
|
|||||||
self,
|
self,
|
||||||
secure_transports_by_protocol: TSecurityOptions,
|
secure_transports_by_protocol: TSecurityOptions,
|
||||||
muxer_transports_by_protocol: TMuxerOptions,
|
muxer_transports_by_protocol: TMuxerOptions,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
):
|
):
|
||||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
self.muxer_multistream = MuxerMultistream(
|
||||||
|
muxer_transports_by_protocol, negotiate_timeout
|
||||||
|
)
|
||||||
|
|
||||||
async def upgrade_security(
|
async def upgrade_security(
|
||||||
self,
|
self,
|
||||||
|
|||||||
1
newsfragments/896.bugfix.rst
Normal file
1
newsfragments/896.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly
|
||||||
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
from unittest.mock import (
|
||||||
|
AsyncMock,
|
||||||
|
MagicMock,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.custom_types import (
|
||||||
|
TMuxerClass,
|
||||||
|
TProtocol,
|
||||||
|
)
|
||||||
|
from libp2p.peer.id import (
|
||||||
|
ID,
|
||||||
|
)
|
||||||
|
from libp2p.protocol_muxer.exceptions import (
|
||||||
|
MultiselectError,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.muxer_multistream import (
|
||||||
|
MuxerMultistream,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_muxer_timeout_configuration():
|
||||||
|
"""Test that muxer respects timeout configuration."""
|
||||||
|
muxer = MuxerMultistream({}, negotiate_timeout=1)
|
||||||
|
assert muxer.negotiate_timeout == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_select_transport_passes_timeout_to_multiselect():
|
||||||
|
"""Test that timeout is passed to multiselect client in select_transport."""
|
||||||
|
# Mock dependencies
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.is_initiator = False
|
||||||
|
|
||||||
|
# Mock MultiselectClient
|
||||||
|
muxer = MuxerMultistream({}, negotiate_timeout=10)
|
||||||
|
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
|
||||||
|
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Call select_transport
|
||||||
|
await muxer.select_transport(mock_conn)
|
||||||
|
|
||||||
|
# Verify that select_one_of was called with the correct timeout
|
||||||
|
args, _ = muxer.multiselect.negotiate.call_args
|
||||||
|
assert args[1] == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_new_conn_passes_timeout_to_multistream_client():
|
||||||
|
"""Test that timeout is passed to multistream client in new_conn."""
|
||||||
|
# Mock dependencies
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.is_initiator = True
|
||||||
|
mock_peer_id = ID(b"test_peer")
|
||||||
|
mock_communicator = MagicMock()
|
||||||
|
|
||||||
|
# Mock MultistreamClient and transports
|
||||||
|
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||||
|
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
|
||||||
|
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Call new_conn
|
||||||
|
await muxer.new_conn(mock_conn, mock_peer_id)
|
||||||
|
|
||||||
|
# Verify that select_one_of was called with the correct timeout
|
||||||
|
muxer.multistream_client.select_one_of(
|
||||||
|
tuple(muxer.transports.keys()), mock_communicator, 30
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_select_transport_no_protocol_selected():
|
||||||
|
"""
|
||||||
|
Test that select_transport raises MultiselectError when no protocol is selected.
|
||||||
|
"""
|
||||||
|
# Mock dependencies
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.is_initiator = False
|
||||||
|
|
||||||
|
# Mock Multiselect to return None
|
||||||
|
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||||
|
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
|
||||||
|
|
||||||
|
# Expect MultiselectError to be raised
|
||||||
|
with pytest.raises(MultiselectError, match="no protocol selected"):
|
||||||
|
await muxer.select_transport(mock_conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_add_transport_updates_precedence():
|
||||||
|
"""Test that adding a transport updates protocol precedence."""
|
||||||
|
# Mock transport classes
|
||||||
|
mock_transport1 = MagicMock(spec=TMuxerClass)
|
||||||
|
mock_transport2 = MagicMock(spec=TMuxerClass)
|
||||||
|
|
||||||
|
# Initialize muxer and add transports
|
||||||
|
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||||
|
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||||
|
muxer.add_transport(TProtocol("proto2"), mock_transport2)
|
||||||
|
|
||||||
|
# Verify transport order
|
||||||
|
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
|
||||||
|
|
||||||
|
# Re-add proto1 to check if it moves to the end
|
||||||
|
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||||
|
assert list(muxer.transports.keys()) == ["proto2", "proto1"]
|
||||||
27
tests/core/transport/test_upgrader.py
Normal file
27
tests/core/transport/test_upgrader.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.custom_types import (
|
||||||
|
TMuxerOptions,
|
||||||
|
TSecurityOptions,
|
||||||
|
)
|
||||||
|
from libp2p.transport.upgrader import (
|
||||||
|
TransportUpgrader,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_transport_upgrader_security_and_muxer_initialization():
|
||||||
|
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
|
||||||
|
secure_transports: TSecurityOptions = {}
|
||||||
|
muxer_transports: TMuxerOptions = {}
|
||||||
|
negotiate_timeout = 15
|
||||||
|
|
||||||
|
upgrader = TransportUpgrader(
|
||||||
|
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify security multistream initialization
|
||||||
|
assert upgrader.security_multistream.transports == secure_transports
|
||||||
|
# Verify muxer multistream initialization and timeout
|
||||||
|
assert upgrader.muxer_multistream.transports == muxer_transports
|
||||||
|
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout
|
||||||
Reference in New Issue
Block a user