mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
241 lines
7.7 KiB
Python
from unittest.mock import (
|
|
AsyncMock,
|
|
patch,
|
|
)
|
|
|
|
import pytest
|
|
|
|
from libp2p.host.autonat.autonat import (
|
|
AUTONAT_PROTOCOL_ID,
|
|
AutoNATService,
|
|
AutoNATStatus,
|
|
)
|
|
from libp2p.host.autonat.pb.autonat_pb2 import (
|
|
DialRequest,
|
|
DialResponse,
|
|
Message,
|
|
PeerInfo,
|
|
Status,
|
|
Type,
|
|
)
|
|
from libp2p.network.stream.exceptions import (
|
|
StreamError,
|
|
)
|
|
from libp2p.network.stream.net_stream import (
|
|
NetStream,
|
|
)
|
|
from libp2p.peer.id import (
|
|
ID,
|
|
)
|
|
from tests.utils.factories import (
|
|
HostFactory,
|
|
)
|
|
|
|
|
|
@pytest.mark.trio
async def test_autonat_service_initialization():
    """A freshly constructed AutoNATService starts in a clean, undetermined state.

    Checks the initial reachability status, the empty dial-result map, and
    that the service keeps references to the host it was built for and to
    that host's peerstore.
    """
    async with HostFactory.create_batch_and_listen(1) as hosts:
        (only_host,) = hosts
        autonat = AutoNATService(only_host)

        # Nothing has been probed yet, so reachability is undetermined.
        assert autonat.status == AutoNATStatus.UNKNOWN
        assert autonat.dial_results == {}

        # The service is wired to the host (and peerstore) it was given.
        assert autonat.host == only_host
        assert autonat.peerstore == only_host.get_peerstore()
|
|
|
|
|
|
@pytest.mark.trio
async def test_autonat_status_getter():
    """get_status() reports whatever value is currently stored in ``status``."""
    async with HostFactory.create_batch_and_listen(1) as hosts:
        autonat = AutoNATService(hosts[0])

        # Before any change the getter reports the initial UNKNOWN status.
        assert autonat.get_status() == AutoNATStatus.UNKNOWN

        # After each direct assignment the getter must echo the new value.
        for expected_status in (AutoNATStatus.PUBLIC, AutoNATStatus.PRIVATE):
            autonat.status = expected_status
            assert autonat.get_status() == expected_status
|
|
|
|
|
|
@pytest.mark.trio
async def test_update_status():
    """Test that the AutoNAT status updates correctly based on dial results.

    ``update_status`` derives reachability from ``dial_results`` (peer ID ->
    whether the dial-back succeeded):

    * no results              -> UNKNOWN
    * fewer than 2 successes  -> PRIVATE
    * 2 or more successes     -> PUBLIC
    """
    async with HostFactory.create_batch_and_listen(1) as hosts:
        host = hosts[0]
        service = AutoNATService(host)

        # ``ID`` wraps raw peer-ID bytes, so build the keys from ``bytes``
        # rather than ``str`` to match the constructor's input type.
        peer1, peer2, peer3 = ID(b"peer1"), ID(b"peer2"), ID(b"peer3")

        # No dial results should result in UNKNOWN status
        service.update_status()
        assert service.status == AutoNATStatus.UNKNOWN

        # Less than 2 successful dials should result in PRIVATE status
        service.dial_results = {
            peer1: True,
            peer2: False,
            peer3: False,
        }
        service.update_status()
        assert service.status == AutoNATStatus.PRIVATE

        # 2 or more successful dials should result in PUBLIC status
        service.dial_results = {
            peer1: True,
            peer2: True,
            peer3: False,
        }
        service.update_status()
        assert service.status == AutoNATStatus.PUBLIC

        # Results present but zero successes is still "fewer than 2" -> PRIVATE.
        # Running it after PUBLIC also checks the status is recomputed rather
        # than sticking at its previous value.
        service.dial_results = {
            peer1: False,
            peer2: False,
            peer3: False,
        }
        service.update_status()
        assert service.status == AutoNATStatus.PRIVATE
|
|
|
|
|
|
@pytest.mark.trio
async def test_try_dial():
    """Test that the try_dial method works correctly.

    On success ``_try_dial`` is expected to open a stream to the peer over
    the AutoNAT protocol, close it, and return ``True``. When opening the
    stream raises, it is expected to return ``False``.
    """
    async with HostFactory.create_batch_and_listen(2) as hosts:
        host1, host2 = hosts
        service = AutoNATService(host1)
        peer_id = host2.get_id()

        # Test successful dial
        with patch.object(
            host1, "new_stream", new_callable=AsyncMock
        ) as mock_new_stream:
            mock_stream = AsyncMock(spec=NetStream)
            mock_new_stream.return_value = mock_stream

            result = await service._try_dial(peer_id)

            assert result is True
            # ``assert_awaited_*`` (not ``assert_called_*``) proves the
            # coroutines were actually awaited. Merely calling an AsyncMock
            # creates a coroutine object that never runs, so an un-awaited
            # ``stream.close()`` would otherwise go unnoticed.
            mock_new_stream.assert_awaited_once_with(peer_id, [AUTONAT_PROTOCOL_ID])
            mock_stream.close.assert_awaited_once()

        # Test failed dial
        with patch.object(
            host1, "new_stream", new_callable=AsyncMock
        ) as mock_new_stream:
            mock_new_stream.side_effect = Exception("Connection failed")

            result = await service._try_dial(peer_id)

            assert result is False
            # AsyncMock records the await before raising its side_effect.
            mock_new_stream.assert_awaited_once_with(peer_id, [AUTONAT_PROTOCOL_ID])
|
|
|
|
|
|
@pytest.mark.trio
async def test_handle_dial():
    """Test that the handle_dial method works correctly.

    A DIAL message naming one peer should trigger exactly one dial-back
    attempt for that peer. The result should be a DIAL_RESPONSE that reports
    the attempt's outcome for that peer.
    """
    async with HostFactory.create_batch_and_listen(2) as hosts:
        host1, host2 = hosts
        service = AutoNATService(host1)
        peer_id = host2.get_id()

        # Create a mock message with a peer to dial
        message = Message()
        message.type = Type.DIAL
        peer_info = PeerInfo()
        peer_info.id = peer_id.to_bytes()
        # NOTE(review): this is the text form of the multiaddr stored as
        # bytes; confirm it matches the encoding _handle_dial expects.
        peer_info.addrs.extend([b"/ip4/127.0.0.1/tcp/4001"])
        message.dial.peers.append(peer_info)

        # Mock the _try_dial method
        with patch.object(
            service, "_try_dial", new_callable=AsyncMock
        ) as mock_try_dial:
            mock_try_dial.return_value = True

            response = await service._handle_dial(message)

            assert response.type == Type.DIAL_RESPONSE
            assert response.dial_response.status == Status.OK
            assert len(response.dial_response.peers) == 1
            assert response.dial_response.peers[0].id == peer_id.to_bytes()
            assert response.dial_response.peers[0].success is True
            # Awaited (not merely called): the dial-back really ran.
            mock_try_dial.assert_awaited_once_with(peer_id)
|
|
|
|
|
|
@pytest.mark.trio
async def test_handle_request():
    """Test that the handle_request method works correctly.

    A serialized DIAL request must be delegated to ``_handle_dial``. An
    UNKNOWN request type must be answered with a DIAL_RESPONSE carrying
    ``E_INTERNAL_ERROR``.
    """
    async with HostFactory.create_batch_and_listen(1) as hosts:
        host = hosts[0]
        service = AutoNATService(host)

        # Test handling a DIAL request
        message = Message()
        message.type = Type.DIAL
        dial_request = DialRequest()
        peer_info = PeerInfo()
        dial_request.peers.append(peer_info)
        message.dial.CopyFrom(dial_request)

        with patch.object(
            service, "_handle_dial", new_callable=AsyncMock
        ) as mock_handle_dial:
            mock_handle_dial.return_value = Message()

            response = await service._handle_request(message.SerializeToString())

            # Awaited (not merely called): the delegate actually ran.
            mock_handle_dial.assert_awaited_once()
            assert isinstance(response, Message)

        # Test handling an unknown request type
        message = Message()
        message.type = Type.UNKNOWN

        response = await service._handle_request(message.SerializeToString())

        assert isinstance(response, Message)
        assert response.type == Type.DIAL_RESPONSE
        assert response.dial_response.status == Status.E_INTERNAL_ERROR
|
|
|
|
|
|
@pytest.mark.trio
async def test_handle_stream():
    """Test that handle_stream correctly processes stream data.

    On the happy path the request is read from the stream and passed to
    ``_handle_request``. The serialized response is then written back and the
    stream is closed. When reading raises ``StreamError``, the stream must
    still be closed.
    """
    async with HostFactory.create_batch_and_listen(1) as hosts:
        host = hosts[0]
        autonat_service = AutoNATService(host)

        # Create a mock stream
        mock_stream = AsyncMock(spec=NetStream)

        # Create a properly initialized request Message
        request = Message()
        request.type = Type.DIAL
        dial_request = DialRequest()
        peer_info = PeerInfo()
        peer_info.id = b"peer_id"
        peer_info.addrs.append(b"addr1")
        dial_request.peers.append(peer_info)
        request.dial.CopyFrom(dial_request)

        # Create a properly initialized response Message
        response = Message()
        response.type = Type.DIAL_RESPONSE
        dial_response = DialResponse()
        dial_response.status = Status.OK
        dial_response.peers.append(peer_info)
        response.dial_response.CopyFrom(dial_response)

        # Mock stream read/write
        mock_stream.read.return_value = request.SerializeToString()
        mock_stream.write.return_value = None

        # Patch _handle_request (instead of permanently overwriting the
        # attribute) so this test is consistent with the others and the
        # real method is restored afterwards.
        with patch.object(
            autonat_service, "_handle_request", new_callable=AsyncMock
        ) as mock_handle_request:
            mock_handle_request.return_value = response

            # Test successful stream handling
            await autonat_service.handle_stream(mock_stream)
            mock_stream.read.assert_awaited_once()
            mock_handle_request.assert_awaited_once()
            mock_stream.write.assert_awaited_once_with(response.SerializeToString())
            mock_stream.close.assert_awaited_once()

            # Test stream error handling
            mock_stream.reset_mock()
            mock_stream.read.side_effect = StreamError("Stream error")
            await autonat_service.handle_stream(mock_stream)
            mock_stream.close.assert_awaited_once()
|