mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into add-ws-transport
This commit is contained in:
@ -213,7 +213,6 @@ class BasicHost(IHost):
|
||||
self,
|
||||
peer_id: ID,
|
||||
protocol_ids: Sequence[TProtocol],
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id that host is connecting
|
||||
@ -227,7 +226,7 @@ class BasicHost(IHost):
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids),
|
||||
MultiselectCommunicator(net_stream),
|
||||
negotitate_timeout,
|
||||
self.negotiate_timeout,
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
|
||||
@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
Multiselect,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect_client import (
|
||||
@ -46,11 +47,17 @@ class MuxerMultistream:
|
||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
negotiate_timeout: int
|
||||
|
||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> None:
|
||||
self.transports = OrderedDict()
|
||||
self.multiselect = Multiselect()
|
||||
self.multistream_client = MultiselectClient()
|
||||
self.negotiate_timeout = negotiate_timeout
|
||||
for protocol, transport in muxer_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
|
||||
@ -80,10 +87,12 @@ class MuxerMultistream:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
protocol, _ = await self.multiselect.negotiate(
|
||||
communicator, self.negotiate_timeout
|
||||
)
|
||||
if protocol is None:
|
||||
raise MultiselectError(
|
||||
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
||||
@ -93,7 +102,7 @@ class MuxerMultistream:
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
protocol = await self.multistream_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
transport_class = self.transports[protocol]
|
||||
if protocol == PROTOCOL_ID:
|
||||
|
||||
@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
)
|
||||
from libp2p.security.exceptions import (
|
||||
HandshakeFailure,
|
||||
)
|
||||
@ -37,9 +40,12 @@ class TransportUpgrader:
|
||||
self,
|
||||
secure_transports_by_protocol: TSecurityOptions,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
):
|
||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(
|
||||
muxer_transports_by_protocol, negotiate_timeout
|
||||
)
|
||||
|
||||
async def upgrade_security(
|
||||
self,
|
||||
|
||||
1
newsfragments/896.bugfix.rst
Normal file
1
newsfragments/896.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly
|
||||
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
@ -0,0 +1,108 @@
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TMuxerClass,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.stream_muxer.muxer_multistream import (
|
||||
MuxerMultistream,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_muxer_timeout_configuration():
|
||||
"""Test that muxer respects timeout configuration."""
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=1)
|
||||
assert muxer.negotiate_timeout == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_transport_passes_timeout_to_multiselect():
|
||||
"""Test that timeout is passed to multiselect client in select_transport."""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = False
|
||||
|
||||
# Mock MultiselectClient
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=10)
|
||||
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
|
||||
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Call select_transport
|
||||
await muxer.select_transport(mock_conn)
|
||||
|
||||
# Verify that select_one_of was called with the correct timeout
|
||||
args, _ = muxer.multiselect.negotiate.call_args
|
||||
assert args[1] == 10
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_new_conn_passes_timeout_to_multistream_client():
|
||||
"""Test that timeout is passed to multistream client in new_conn."""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = True
|
||||
mock_peer_id = ID(b"test_peer")
|
||||
mock_communicator = MagicMock()
|
||||
|
||||
# Mock MultistreamClient and transports
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
|
||||
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Call new_conn
|
||||
await muxer.new_conn(mock_conn, mock_peer_id)
|
||||
|
||||
# Verify that select_one_of was called with the correct timeout
|
||||
muxer.multistream_client.select_one_of(
|
||||
tuple(muxer.transports.keys()), mock_communicator, 30
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_transport_no_protocol_selected():
|
||||
"""
|
||||
Test that select_transport raises MultiselectError when no protocol is selected.
|
||||
"""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = False
|
||||
|
||||
# Mock Multiselect to return None
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
|
||||
|
||||
# Expect MultiselectError to be raised
|
||||
with pytest.raises(MultiselectError, match="no protocol selected"):
|
||||
await muxer.select_transport(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_add_transport_updates_precedence():
|
||||
"""Test that adding a transport updates protocol precedence."""
|
||||
# Mock transport classes
|
||||
mock_transport1 = MagicMock(spec=TMuxerClass)
|
||||
mock_transport2 = MagicMock(spec=TMuxerClass)
|
||||
|
||||
# Initialize muxer and add transports
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||
muxer.add_transport(TProtocol("proto2"), mock_transport2)
|
||||
|
||||
# Verify transport order
|
||||
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
|
||||
|
||||
# Re-add proto1 to check if it moves to the end
|
||||
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||
assert list(muxer.transports.keys()) == ["proto2", "proto1"]
|
||||
27
tests/core/transport/test_upgrader.py
Normal file
27
tests/core/transport/test_upgrader.py
Normal file
@ -0,0 +1,27 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TMuxerOptions,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_transport_upgrader_security_and_muxer_initialization():
|
||||
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
|
||||
secure_transports: TSecurityOptions = {}
|
||||
muxer_transports: TMuxerOptions = {}
|
||||
negotiate_timeout = 15
|
||||
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
|
||||
)
|
||||
|
||||
# Verify security multistream initialization
|
||||
assert upgrader.security_multistream.transports == secure_transports
|
||||
# Verify muxer multistream initialization and timeout
|
||||
assert upgrader.muxer_multistream.transports == muxer_transports
|
||||
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout
|
||||
Reference in New Issue
Block a user