feat/606-enable-nat-traversal-via-hole-punching (#668)

* feat: base implementation of dcutr for hole-punching

* chore: removed circuit-relay imports from __init__

* feat: implemented dcutr protocol

* added test suite with mock setup

* Fix pre-commit hook issues in DCUtR implementation

* usages of CONNECT_TYPE and SYNC_TYPE have been replaced with HolePunch.Type.CONNECT and HolePunch.Type.SYNC

* added unit tests for dcutr and nat module and

* added multiaddr.get_peer_id() with proper DNS address handling and fixed method signature inconsistencies

* added assertions to verify DCUtR hole punch result in integration test

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
This commit is contained in:
Soham Bhoir
2025-08-08 06:30:16 +05:30
committed by GitHub
parent 9ed44f5fa3
commit cb11f076c8
14 changed files with 2099 additions and 1 deletions

View File

@ -0,0 +1,208 @@
"""Unit tests for DCUtR protocol."""
import logging
from unittest.mock import AsyncMock, MagicMock
import pytest
import trio
from libp2p.abc import INetStream
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.dcutr import (
MAX_HOLE_PUNCH_ATTEMPTS,
DCUtRProtocol,
)
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import HolePunch
from libp2p.tools.async_service import background_trio_service
logger = logging.getLogger(__name__)
@pytest.mark.trio
async def test_dcutr_protocol_initialization():
"""Test DCUtR protocol initialization."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol is initialized correctly
assert dcutr.host == mock_host
assert not dcutr.event_started.is_set()
assert dcutr._hole_punch_attempts == {}
assert dcutr._direct_connections == set()
assert dcutr._in_progress == set()
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Verify that the stream handler was registered
mock_host.set_stream_handler.assert_called_once()
# Verify that the event is set
assert dcutr.event_started.is_set()
@pytest.mark.trio
async def test_dcutr_message_exchange():
"""Test DCUtR message exchange."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Test that the protocol can be started
async with background_trio_service(dcutr):
# Wait for the protocol to start
await dcutr.event_started.wait()
# Test CONNECT message
connect_msg = HolePunch(
type=HolePunch.CONNECT,
ObsAddrs=[b"/ip4/127.0.0.1/tcp/1234", b"/ip4/192.168.1.1/tcp/5678"],
)
# Test SYNC message
sync_msg = HolePunch(type=HolePunch.SYNC)
# Verify message types
assert connect_msg.type == HolePunch.CONNECT
assert sync_msg.type == HolePunch.SYNC
assert len(connect_msg.ObsAddrs) == 2
@pytest.mark.trio
async def test_dcutr_error_handling(monkeypatch):
"""Test DCUtR error handling."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
async with background_trio_service(dcutr):
await dcutr.event_started.wait()
# Simulate a stream that times out
class TimeoutStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
await trio.sleep(0.2)
raise trio.TooSlowError()
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
# Should not raise, just log and close
await dcutr._handle_dcutr_stream(TimeoutStream())
# Simulate a stream with malformed message
class MalformedStream(INetStream):
def __init__(self):
self._protocol = None
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
async def read(self, n: int | None = None) -> bytes:
return b"not-a-protobuf"
async def write(self, data: bytes) -> None:
return None
async def close(self, *args, **kwargs):
return None
async def reset(self):
return None
def get_protocol(self):
return self._protocol
def set_protocol(self, protocol_id):
self._protocol = protocol_id
def get_remote_address(self):
return ("127.0.0.1", 1234)
await dcutr._handle_dcutr_stream(MalformedStream())
@pytest.mark.trio
async def test_dcutr_max_attempts_and_already_connected():
"""Test max hole punch attempts and already-connected peer."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Simulate already having a direct connection
dcutr._direct_connections.add(peer_id)
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True
# Remove direct connection, simulate max attempts
dcutr._direct_connections.clear()
dcutr._hole_punch_attempts[peer_id] = MAX_HOLE_PUNCH_ATTEMPTS
result = await dcutr.initiate_hole_punch(peer_id)
assert result is False
@pytest.mark.trio
async def test_dcutr_observed_addr_encoding_decoding():
"""Test observed address encoding/decoding."""
from multiaddr import Multiaddr
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
# Simulate valid and invalid multiaddrs as bytes
valid = [
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
]
invalid = [b"not-a-multiaddr", b""]
decoded = dcutr._decode_observed_addrs(valid + invalid)
assert len(decoded) == 2
@pytest.mark.trio
async def test_dcutr_real_perform_hole_punch(monkeypatch):
"""Test initiate_hole_punch with real _perform_hole_punch logic (mock network)."""
mock_host = MagicMock()
dcutr = DCUtRProtocol(mock_host)
peer_id = ID(b"peer")
# Patch methods to simulate a successful punch
monkeypatch.setattr(dcutr, "_have_direct_connection", AsyncMock(return_value=False))
monkeypatch.setattr(
dcutr,
"_get_observed_addrs",
AsyncMock(return_value=[b"/ip4/127.0.0.1/tcp/1234"]),
)
mock_stream = MagicMock()
mock_stream.read = AsyncMock(
side_effect=[
HolePunch(
type=HolePunch.CONNECT, ObsAddrs=[b"/ip4/192.168.1.1/tcp/4321"]
).SerializeToString(),
HolePunch(type=HolePunch.SYNC).SerializeToString(),
]
)
mock_stream.write = AsyncMock()
mock_stream.close = AsyncMock()
mock_stream.muxed_conn = MagicMock(peer_id=peer_id)
mock_host.new_stream = AsyncMock(return_value=mock_stream)
monkeypatch.setattr(dcutr, "_perform_hole_punch", AsyncMock(return_value=True))
result = await dcutr.initiate_hole_punch(peer_id)
assert result is True