mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
* feat: base implementation of dcutr for hole-punching * chore: removed circuit-relay imports from __init__ * feat: implemented dcutr protocol * added test suite with mock setup * Fix pre-commit hook issues in DCUtR implementation * usages of CONNECT_TYPE and SYNC_TYPE have been replaced with HolePunch.Type.CONNECT and HolePunch.Type.SYNC * added unit tests for dcutr and nat module * added multiaddr.get_peer_id() with proper DNS address handling and fixed method signature inconsistencies * added assertions to verify DCUtR hole punch result in integration test --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
209 lines
6.4 KiB
Python
209 lines
6.4 KiB
Python
"""Unit tests for DCUtR protocol."""
|
|
|
|
import logging
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import trio
|
|
|
|
from libp2p.abc import INetStream
|
|
from libp2p.peer.id import ID
|
|
from libp2p.relay.circuit_v2.dcutr import (
|
|
MAX_HOLE_PUNCH_ATTEMPTS,
|
|
DCUtRProtocol,
|
|
)
|
|
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import HolePunch
|
|
from libp2p.tools.async_service import background_trio_service
|
|
|
|
# Logger for this test module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_protocol_initialization():
    """Test DCUtR protocol initialization and startup.

    Checks the state of a freshly constructed protocol, then runs it as a
    background service and checks that startup calls
    ``host.set_stream_handler`` exactly once and sets ``event_started``.
    """
    mock_host = MagicMock()
    dcutr = DCUtRProtocol(mock_host)

    # Freshly constructed: bound to the given host, not started, no state.
    assert dcutr.host is mock_host
    assert not dcutr.event_started.is_set()
    assert dcutr._hole_punch_attempts == {}
    assert dcutr._direct_connections == set()
    assert dcutr._in_progress == set()

    async with background_trio_service(dcutr):
        # Bound the wait so a service that never starts fails the test
        # instead of hanging the whole suite.
        with trio.fail_after(5):
            await dcutr.event_started.wait()

        # Startup should have registered the protocol's stream handler.
        mock_host.set_stream_handler.assert_called_once()

        assert dcutr.event_started.is_set()
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_message_exchange():
    """Test construction and wire round-trip of DCUtR HolePunch messages.

    Builds a CONNECT message carrying observed addresses and a bare SYNC
    message and checks their fields. It then verifies that each message
    survives ``SerializeToString`` / ``ParseFromString`` unchanged, which is
    what actually travels over a DCUtR stream.
    """
    mock_host = MagicMock()
    dcutr = DCUtRProtocol(mock_host)

    async with background_trio_service(dcutr):
        # Bound the wait so a service that never starts fails instead of hanging.
        with trio.fail_after(5):
            await dcutr.event_started.wait()

        obs_addrs = [b"/ip4/127.0.0.1/tcp/1234", b"/ip4/192.168.1.1/tcp/5678"]

        # CONNECT carries the sender's observed addresses; SYNC carries none.
        connect_msg = HolePunch(type=HolePunch.CONNECT, ObsAddrs=obs_addrs)
        sync_msg = HolePunch(type=HolePunch.SYNC)

        assert connect_msg.type == HolePunch.CONNECT
        assert sync_msg.type == HolePunch.SYNC
        assert list(connect_msg.ObsAddrs) == obs_addrs

        # Each message must decode back to an identical message.
        for sent in (connect_msg, sync_msg):
            received = HolePunch()
            received.ParseFromString(sent.SerializeToString())
            assert received == sent
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_error_handling(monkeypatch):
    """Test that the DCUtR stream handler tolerates failing streams.

    ``_handle_dcutr_stream`` is fed two streams:
    - one whose ``read`` raises ``trio.TooSlowError``;
    - one whose ``read`` returns bytes that are not a valid protobuf.

    The test passes if neither call raises. ``monkeypatch`` is currently
    unused; it is kept so the pytest signature stays the same.
    """
    mock_host = MagicMock()
    dcutr = DCUtRProtocol(mock_host)

    async with background_trio_service(dcutr):
        # Bound the wait so a service that never starts fails instead of hanging.
        with trio.fail_after(5):
            await dcutr.event_started.wait()

        class _StubStream(INetStream):
            """No-op INetStream stub; subclasses supply the ``read`` behavior."""

            def __init__(self):
                self._protocol = None
                # The handler may look up the remote peer via muxed_conn.
                self.muxed_conn = MagicMock(peer_id=ID(b"peer"))

            async def write(self, data: bytes) -> None:
                return None

            async def close(self, *args, **kwargs):
                return None

            async def reset(self):
                return None

            def get_protocol(self):
                return self._protocol

            def set_protocol(self, protocol_id):
                self._protocol = protocol_id

            def get_remote_address(self):
                return ("127.0.0.1", 1234)

        class TimeoutStream(_StubStream):
            """Stream whose read stalls briefly and then times out."""

            async def read(self, n: int | None = None) -> bytes:
                await trio.sleep(0.2)
                raise trio.TooSlowError()

        class MalformedStream(_StubStream):
            """Stream whose read yields bytes that are not a HolePunch message."""

            async def read(self, n: int | None = None) -> bytes:
                return b"not-a-protobuf"

        # Neither call should raise; the handler is expected to log and close.
        await dcutr._handle_dcutr_stream(TimeoutStream())
        await dcutr._handle_dcutr_stream(MalformedStream())
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_max_attempts_and_already_connected():
    """Check the two short-circuit outcomes of ``initiate_hole_punch``.

    It should return True for a peer already in ``_direct_connections``.
    It should return False once that peer's attempt count in
    ``_hole_punch_attempts`` equals ``MAX_HOLE_PUNCH_ATTEMPTS``.
    """
    protocol = DCUtRProtocol(MagicMock())
    target = ID(b"peer")

    # Case 1: a direct connection to the peer is already recorded.
    protocol._direct_connections.add(target)
    assert await protocol.initiate_hole_punch(target) is True

    # Case 2: no direct connection, but the attempt budget is exhausted.
    protocol._direct_connections.clear()
    protocol._hole_punch_attempts[target] = MAX_HOLE_PUNCH_ATTEMPTS
    assert await protocol.initiate_hole_punch(target) is False
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_observed_addr_encoding_decoding():
    """Decode a mix of valid and invalid multiaddr bytes.

    Expects ``_decode_observed_addrs`` to yield exactly one entry per valid
    binary-encoded multiaddr in the input.
    """
    from multiaddr import Multiaddr

    protocol = DCUtRProtocol(MagicMock())

    good_addrs = [
        Multiaddr(text).to_bytes()
        for text in ("/ip4/127.0.0.1/tcp/1234", "/ip4/192.168.1.1/tcp/5678")
    ]
    # Garbage and empty payloads should not contribute any entries.
    bad_addrs = [b"not-a-multiaddr", b""]

    result = protocol._decode_observed_addrs(good_addrs + bad_addrs)
    assert len(result) == 2
|
|
|
|
|
|
@pytest.mark.trio
async def test_dcutr_real_perform_hole_punch(monkeypatch):
    """Test initiate_hole_punch with a mocked relay stream.

    Despite the test name, ``_perform_hole_punch`` itself is stubbed to return
    True, so the actual dial logic is NOT exercised here. The following are
    also patched:
    - ``_have_direct_connection`` and ``_get_observed_addrs``;
    - ``host.new_stream``, which returns a scripted stream whose reads yield a
      CONNECT and then a SYNC message (presumably consumed by the initiator
      side of the exchange; verify against ``initiate_hole_punch``).
    """
    mock_host = MagicMock()
    dcutr = DCUtRProtocol(mock_host)
    peer_id = ID(b"peer")

    # No existing direct connection, so a hole punch must be attempted.
    monkeypatch.setattr(dcutr, "_have_direct_connection", AsyncMock(return_value=False))
    monkeypatch.setattr(
        dcutr,
        "_get_observed_addrs",
        AsyncMock(return_value=[b"/ip4/127.0.0.1/tcp/1234"]),
    )

    # Scripted stream: successive reads return these serialized messages.
    mock_stream = MagicMock()
    mock_stream.read = AsyncMock(
        side_effect=[
            HolePunch(
                type=HolePunch.CONNECT, ObsAddrs=[b"/ip4/192.168.1.1/tcp/4321"]
            ).SerializeToString(),
            HolePunch(type=HolePunch.SYNC).SerializeToString(),
        ]
    )
    mock_stream.write = AsyncMock()
    mock_stream.close = AsyncMock()
    mock_stream.muxed_conn = MagicMock(peer_id=peer_id)
    mock_host.new_stream = AsyncMock(return_value=mock_stream)

    # Stub the network-level punch so the outcome is deterministic.
    perform_hole_punch = AsyncMock(return_value=True)
    monkeypatch.setattr(dcutr, "_perform_hole_punch", perform_hole_punch)

    result = await dcutr.initiate_hole_punch(peer_id)
    assert result is True
    # The success must have come from the (stubbed) punch step being reached.
    perform_hole_punch.assert_awaited()
|