Files
py-libp2p/tests/core/host/test_autonat.py
2025-05-09 17:31:15 -06:00

241 lines
7.7 KiB
Python

from unittest.mock import (
AsyncMock,
patch,
)
import pytest
from libp2p.host.autonat.autonat import (
AUTONAT_PROTOCOL_ID,
AutoNATService,
AutoNATStatus,
)
from libp2p.host.autonat.pb.autonat_pb2 import (
DialRequest,
DialResponse,
Message,
PeerInfo,
Status,
Type,
)
from libp2p.network.stream.exceptions import (
StreamError,
)
from libp2p.network.stream.net_stream import (
NetStream,
)
from libp2p.peer.id import (
ID,
)
from tests.utils.factories import (
HostFactory,
)
@pytest.mark.trio
async def test_autonat_service_initialization():
    """A freshly built AutoNATService is UNKNOWN, has no dial results, and is bound to its host."""
    async with HostFactory.create_batch_and_listen(1) as hosts:
        first_host = hosts[0]
        svc = AutoNATService(first_host)

        # Nothing has been dialed yet, so no verdict and no recorded results.
        assert svc.status == AutoNATStatus.UNKNOWN
        assert svc.dial_results == {}

        # The service keeps the host it was given and that host's peerstore.
        assert svc.host == first_host
        assert svc.peerstore == first_host.get_peerstore()
@pytest.mark.trio
async def test_autonat_status_getter():
    """get_status() must report whatever value currently sits in ``status``."""
    async with HostFactory.create_batch_and_listen(1) as hosts:
        svc = AutoNATService(hosts[0])

        # Before any assignment the getter returns the initial status.
        assert svc.get_status() == AutoNATStatus.UNKNOWN

        # After each direct assignment the getter echoes the new value.
        for new_status in (AutoNATStatus.PUBLIC, AutoNATStatus.PRIVATE):
            svc.status = new_status
            assert svc.get_status() == new_status
@pytest.mark.trio
async def test_update_status():
    """
    Test that the AutoNAT status updates correctly based on dial results.

    Expected mapping exercised below:
    - no recorded dial results       -> UNKNOWN
    - fewer than 2 successful dials  -> PRIVATE
    - 2 or more successful dials     -> PUBLIC
    """
    async with HostFactory.create_batch_and_listen(1) as hosts:
        host = hosts[0]
        service = AutoNATService(host)

        # No dial results should result in UNKNOWN status
        service.update_status()
        assert service.status == AutoNATStatus.UNKNOWN

        # Less than 2 successful dials should result in PRIVATE status.
        # ``ID`` wraps raw peer-ID bytes (cf. ``peer_id.to_bytes()`` used in the
        # other tests), so the dummy keys are built from bytes literals, not str.
        service.dial_results = {
            ID(b"peer1"): True,
            ID(b"peer2"): False,
            ID(b"peer3"): False,
        }
        service.update_status()
        assert service.status == AutoNATStatus.PRIVATE

        # 2 or more successful dials should result in PUBLIC status
        service.dial_results = {
            ID(b"peer1"): True,
            ID(b"peer2"): True,
            ID(b"peer3"): False,
        }
        service.update_status()
        assert service.status == AutoNATStatus.PUBLIC
@pytest.mark.trio
async def test_try_dial():
    """Exercise _try_dial on both its success path and its failure path."""
    async with HostFactory.create_batch_and_listen(2) as hosts:
        dialer, target = hosts
        service = AutoNATService(dialer)
        target_id = target.get_id()
        # Both scenarios must open exactly one stream on the AutoNAT protocol.
        expected_call = (target_id, [AUTONAT_PROTOCOL_ID])

        # Success: new_stream hands back a stream, which _try_dial must close.
        fake_stream = AsyncMock(spec=NetStream)
        with patch.object(
            dialer, "new_stream", new_callable=AsyncMock, return_value=fake_stream
        ) as opener:
            outcome = await service._try_dial(target_id)
            assert outcome is True
            opener.assert_called_once_with(*expected_call)
            fake_stream.close.assert_called_once()

        # Failure: an exception raised by new_stream is reported as False.
        with patch.object(
            dialer,
            "new_stream",
            new_callable=AsyncMock,
            side_effect=Exception("Connection failed"),
        ) as opener:
            outcome = await service._try_dial(target_id)
            assert outcome is False
            opener.assert_called_once_with(*expected_call)
@pytest.mark.trio
async def test_handle_dial():
    """_handle_dial must dial each requested peer and report the outcome per peer."""
    async with HostFactory.create_batch_and_listen(2) as hosts:
        local_host, remote_host = hosts
        service = AutoNATService(local_host)
        remote_id = remote_host.get_id()
        remote_id_bytes = remote_id.to_bytes()

        # A DIAL request naming exactly one peer.
        target = PeerInfo(id=remote_id_bytes, addrs=[b"/ip4/127.0.0.1/tcp/4001"])
        request = Message(type=Type.DIAL)
        request.dial.peers.append(target)

        # Stub the real dial so only the response assembly is under test.
        with patch.object(
            service, "_try_dial", new_callable=AsyncMock, return_value=True
        ) as fake_try_dial:
            reply = await service._handle_dial(request)

            assert reply.type == Type.DIAL_RESPONSE
            body = reply.dial_response
            assert body.status == Status.OK
            assert len(body.peers) == 1
            (reported,) = body.peers
            assert reported.id == remote_id_bytes
            assert reported.success is True
            fake_try_dial.assert_called_once_with(remote_id)
@pytest.mark.trio
async def test_handle_request():
    """_handle_request routes DIAL messages to _handle_dial and rejects other types."""
    async with HostFactory.create_batch_and_listen(1) as hosts:
        service = AutoNATService(hosts[0])

        # Serialized DIAL request carrying a single (empty) peer entry.
        dial_msg = Message()
        dial_msg.type = Type.DIAL
        request_body = DialRequest()
        request_body.peers.append(PeerInfo())
        dial_msg.dial.CopyFrom(request_body)
        dial_wire = dial_msg.SerializeToString()

        # DIAL requests are forwarded to _handle_dial (stubbed here).
        with patch.object(
            service, "_handle_dial", new_callable=AsyncMock, return_value=Message()
        ) as fake_handle_dial:
            result = await service._handle_request(dial_wire)
            fake_handle_dial.assert_called_once()
            assert isinstance(result, Message)

        # Any other message type yields a DIAL_RESPONSE carrying an error status.
        unknown_wire = Message(type=Type.UNKNOWN).SerializeToString()
        result = await service._handle_request(unknown_wire)
        assert isinstance(result, Message)
        assert result.type == Type.DIAL_RESPONSE
        assert result.dial_response.status == Status.E_INTERNAL_ERROR
@pytest.mark.trio
async def test_handle_stream():
    """handle_stream reads a request, writes the reply, and closes the stream even on read errors."""
    async with HostFactory.create_batch_and_listen(1) as hosts:
        autonat_service = AutoNATService(hosts[0])

        # One peer entry shared by both the request and the canned response.
        peer_info = PeerInfo(id=b"peer_id", addrs=[b"addr1"])

        request = Message(type=Type.DIAL)
        request.dial.CopyFrom(DialRequest(peers=[peer_info]))

        response = Message(type=Type.DIAL_RESPONSE)
        response.dial_response.CopyFrom(
            DialResponse(status=Status.OK, peers=[peer_info])
        )

        # Stream double: read() yields the serialized request; write() does nothing.
        stream = AsyncMock(spec=NetStream)
        stream.read.return_value = request.SerializeToString()
        stream.write.return_value = None
        # Bypass request processing so only the stream plumbing is exercised.
        autonat_service._handle_request = AsyncMock(return_value=response)

        # Happy path: one read, one write of the serialized response, one close.
        await autonat_service.handle_stream(stream)
        stream.read.assert_called_once()
        stream.write.assert_called_once_with(response.SerializeToString())
        stream.close.assert_called_once()

        # Error path: a failing read must still end with the stream being closed.
        stream.reset_mock()
        stream.read.side_effect = StreamError("Stream error")
        await autonat_service.handle_stream(stream)
        stream.close.assert_called_once()