Files
py-libp2p/tests/core/relay/test_dcutr_protocol.py
Soham Bhoir cb11f076c8 feat/606-enable-nat-traversal-via-hole-punching (#668)
* feat: base implementation of dcutr for hole-punching

* chore: removed circuit-relay imports from __init__

* feat: implemented dcutr protocol

* added test suite with mock setup

* Fix pre-commit hook issues in DCUtR implementation

* usages of CONNECT_TYPE and SYNC_TYPE have been replaced with HolePunch.Type.CONNECT and HolePunch.Type.SYNC

* added unit tests for dcutr and nat module and

* added multiaddr.get_peer_id() with proper DNS address handling and fixed method signature inconsistencies

* added assertions to verify DCUtR hole punch result in integration test

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
2025-08-07 19:00:16 -06:00

209 lines
6.4 KiB
Python

"""Unit tests for DCUtR protocol."""
import logging
from unittest.mock import AsyncMock, MagicMock
import pytest
import trio
from libp2p.abc import INetStream
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.dcutr import (
MAX_HOLE_PUNCH_ATTEMPTS,
DCUtRProtocol,
)
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import HolePunch
from libp2p.tools.async_service import background_trio_service
logger = logging.getLogger(__name__)
@pytest.mark.trio
async def test_dcutr_protocol_initialization():
"""Test DCUtR protocol initialization."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol is initialized correctly
assert dcutr.host == mock_host
assert not dcutr.event_started.is_set()
assert dcutr._hole_punch_attempts == {}
assert dcutr._direct_connections == set()
assert dcutr._in_progress == set()
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Verify that the stream handler was registered
mock_host.set_stream_handler.assert_called_once()
# Verify that the event is set
assert dcutr.event_started.is_set()
@pytest.mark.trio
async def test_dcutr_message_exchange():
"""Test DCUtR message exchange."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Test CONNECT message
connect_msg = HolePunch(
type=HolePunch.CONNECT,
ObsAddrs=[b"/ip4/127.0.0.1/tcp/1234", b"/ip4/192.168.1.1/tcp/5678"],
)
# Test SYNC message
sync_msg = HolePunch(type=HolePunch.SYNC)
# Verify message types
assert connect_msg.type == HolePunch.CONNECT
assert sync_msg.type == HolePunch.SYNC
assert len(connect_msg.ObsAddrs) == 2
@pytest.mark.trio
async def test_dcutr_error_handling(monkeypatch):
"""Test DCUtR error handling."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
async with background_trio_service(dcutr):
await dcutr.event_started.wait()
# Simulate a stream that times out
class TimeoutStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
await trio.sleep(0.2)
raise trio.TooSlowError()
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
# Should not raise, just log and close
await dcutr._handle_dcutr_stream(TimeoutStream())
# Simulate a stream with malformed message
class MalformedStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
return b"not-a-protobuf"
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
await dcutr._handle_dcutr_stream(MalformedStream())
@pytest.mark.trio
async def test_dcutr_max_attempts_and_already_connected():
"""Test max hole punch attempts and already-connected peer."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Simulate already having a direct connection
dcutr._direct_connections.add(peer_id)
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True
# Remove direct connection, simulate max attempts
dcutr._direct_connections.clear()
dcutr._hole_punch_attempts[peer_id] = MAX_HOLE_PUNCH_ATTEMPTS
result = await dcutr.initiate_hole_punch(peer_id)
assert result is False
@pytest.mark.trio
async def test_dcutr_observed_addr_encoding_decoding():
"""Test observed address encoding/decoding."""
from multiaddr import Multiaddr
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Simulate valid and invalid multiaddrs as bytes
valid = [
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
]
invalid = [b"not-a-multiaddr", b""]
decoded = dcutr._decode_observed_addrs(valid + invalid)
assert len(decoded) == 2
@pytest.mark.trio
async def test_dcutr_real_perform_hole_punch(monkeypatch):
"""Test initiate_hole_punch with real _perform_hole_punch logic (mock network)."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Patch methods to simulate a successful punch
monkeypatch.setattr(dcutr, "_have_direct_connection", AsyncMock(return_value=False))
monkeypatch.setattr(
dcutr,
"_get_observed_addrs",
AsyncMock(return_value=[b"/ip4/127.0.0.1/tcp/1234"]),
)
mock_stream = MagicMock()
mock_stream.read = AsyncMock(
side_effect=[
HolePunch(
type=HolePunch.CONNECT, ObsAddrs=[b"/ip4/192.168.1.1/tcp/4321"]
).SerializeToString(),
HolePunch(type=HolePunch.SYNC).SerializeToString(),
]
)
mock_stream.write = AsyncMock()
mock_stream.close = AsyncMock()
mock_stream.muxed_conn = MagicMock(peer_id=peer_id)
mock_host.new_stream = AsyncMock(return_value=mock_stream)
monkeypatch.setattr(dcutr, "_perform_hole_punch", AsyncMock(return_value=True))
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True