diff --git a/libp2p/host/autonat/__init__.py b/libp2p/host/autonat/__init__.py new file mode 100644 index 00000000..56ed5aab --- /dev/null +++ b/libp2p/host/autonat/__init__.py @@ -0,0 +1,5 @@ +from .autonat import ( + AutoNATService, +) + +__all__ = ["AutoNATService"] diff --git a/libp2p/host/autonat/autonat.py b/libp2p/host/autonat/autonat.py new file mode 100644 index 00000000..acaa8a12 --- /dev/null +++ b/libp2p/host/autonat/autonat.py @@ -0,0 +1,212 @@ +import logging + +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.host.autonat.pb.autonat_pb2 import ( + DialResponse, + Message, + PeerInfo, + Status, + Type, +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.network.stream.net_stream import ( + NetStream, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerstore import ( + IPeerStore, +) + +AUTONAT_PROTOCOL_ID = TProtocol("/ipfs/autonat/1.0.0") + +logger = logging.getLogger("libp2p.host.autonat") + + +class AutoNATStatus: + """ + Status enumeration for the AutoNAT service. + + This class defines the possible states of NAT traversal for a node: + - UNKNOWN (0): The node's NAT status has not been determined yet + - PUBLIC (1): The node is publicly reachable + - PRIVATE (2): The node is behind NAT and not directly reachable + """ + + UNKNOWN = 0 + PUBLIC = 1 + PRIVATE = 2 + + +class AutoNATService: + """ + Service for determining if a node is publicly reachable. + + The AutoNAT service helps nodes determine their NAT status by attempting + to establish connections with other peers. It maintains a record of dial + attempts and their results to classify the node as public or private. + """ + + def __init__(self, host: BasicHost) -> None: + """ + Initialize the AutoNAT service. + + Args: + host (BasicHost): The libp2p host instance that provides networking + capabilities for the AutoNAT service, including peer discovery + and connection management. + """ + self.host = host + self.peerstore: IPeerStore = host.get_peerstore() + self.status = AutoNATStatus.UNKNOWN + self.dial_results: dict[ID, bool] = {} + + async def handle_stream(self, stream: NetStream) -> None: + """ + Handle an incoming stream. + + :param stream: The stream to handle + """ + try: + request_bytes = await stream.read() + request = Message() + request.ParseFromString(request_bytes) + response = await self._handle_request(request) + await stream.write(response.SerializeToString()) + except Exception as e: + logger.error("Error handling AutoNAT stream: %s", str(e)) + finally: + await stream.close() + + async def _handle_request(self, request: bytes | Message) -> Message: + """ + Handle an AutoNAT request. + + Parses and processes incoming AutoNAT requests, routing them to + appropriate handlers based on the message type. + + Args: + request (bytes | Message): The request bytes that need to be parsed + and handled by the AutoNAT service, or a Message object directly. + + Returns: + Message: The response message containing the result of the request. + Returns an error response if the request type is not recognized. + """ + if isinstance(request, bytes): + message = Message() + message.ParseFromString(request) + else: + message = request + + if message.type == Type.Value("DIAL"): + response = await self._handle_dial(message) + return response + + # Handle unknown request type + response = Message() + response.type = Type.Value("DIAL_RESPONSE") + error_response = DialResponse() + error_response.status = Status.E_INTERNAL_ERROR + response.dial_response.CopyFrom(error_response) + return response + + async def _handle_dial(self, message: Message) -> Message: + """ + Handle a DIAL request. + + Processes dial requests by attempting to connect to specified peers + and recording the results of these connection attempts. + + Args: + message (Message): The request message containing the dial request + parameters and peer information. + + Returns: + Message: The response message containing the dial results, including + success/failure status for each attempted peer connection. + """ + response = Message() + response.type = Type.Value("DIAL_RESPONSE") + dial_response = DialResponse() + dial_response.status = Status.OK + + for peer in message.dial.peers: + peer_id = ID(peer.id) + if peer_id in self.dial_results: + success = self.dial_results[peer_id] + else: + success = await self._try_dial(peer_id) + self.dial_results[peer_id] = success + + peer_info = PeerInfo() + peer_info.id = peer_id.to_bytes() + peer_info.addrs.extend(peer.addrs) + peer_info.success = success + dial_response.peers.append(peer_info) + + # Initialize the dial_response field if it doesn't exist + if not hasattr(response, "dial_response"): + response.dial_response = DialResponse() + response.dial_response.CopyFrom(dial_response) + return response + + async def _try_dial(self, peer_id: ID) -> bool: + """ + Try to dial a peer. + + Attempts to establish a connection with a specified peer to test + NAT traversal capabilities. + + Args: + peer_id (ID): The identifier of the peer to attempt to dial for + NAT traversal testing. + + Returns: + bool: True if the dial was successful and a connection could be + established, False if the connection attempt failed. + """ + try: + stream = await self.host.new_stream(peer_id, [AUTONAT_PROTOCOL_ID]) + await stream.close() + return True + except Exception: + return False + + def get_status(self) -> int: + """ + Get the current AutoNAT status. + + Retrieves the current NAT status of the node based on previous + dial attempts and their results. + + Returns: + int: The current status as an integer: + - AutoNATStatus.UNKNOWN (0): Status not yet determined + - AutoNATStatus.PUBLIC (1): Node is publicly reachable + - AutoNATStatus.PRIVATE (2): Node is behind NAT + """ + return self.status + + def update_status(self) -> None: + """ + Update the AutoNAT status based on dial results. + + Analyzes the results of previous dial attempts to determine if the + node is publicly reachable. The node is considered public if at + least two successful dial attempts have been recorded. + """ + if not self.dial_results: + self.status = AutoNATStatus.UNKNOWN + return + + success_count = sum(1 for success in self.dial_results.values() if success) + if success_count >= 2: + self.status = AutoNATStatus.PUBLIC + else: + self.status = AutoNATStatus.PRIVATE diff --git a/libp2p/host/autonat/pb/__init__.py b/libp2p/host/autonat/pb/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libp2p/host/autonat/pb/autonat.proto b/libp2p/host/autonat/pb/autonat.proto new file mode 100644 index 00000000..b9351395 --- /dev/null +++ b/libp2p/host/autonat/pb/autonat.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package autonat.pb; + +// AutoNAT service definition +service AutoNAT { + rpc Dial (Message) returns (Message) {} +} + +// Message types +enum Type { + UNKNOWN = 0; + DIAL = 1; + DIAL_RESPONSE = 2; +} + +// Status codes +enum Status { + OK = 0; + E_DIAL_ERROR = 1; + E_DIAL_REFUSED = 2; + E_DIAL_FAILED = 3; + E_INTERNAL_ERROR = 100; +} + +// Main message +message Message { + Type type = 1; + DialRequest dial = 2; + DialResponse dial_response = 3; +} + +// Dial request +message DialRequest { + repeated PeerInfo peers = 1; +} + +// Dial response +message DialResponse { + Status status = 1; + repeated PeerInfo peers = 2; +} + +// Peer information +message PeerInfo { + bytes id = 1; + repeated bytes addrs = 2; + bool success = 3; +} diff --git a/libp2p/host/autonat/pb/autonat_pb2.py b/libp2p/host/autonat/pb/autonat_pb2.py new file mode 100644 index 00000000..b359edf2 --- /dev/null +++ b/libp2p/host/autonat/pb/autonat_pb2.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: libp2p/host/autonat/pb/autonat.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'libp2p/host/autonat/pb/autonat.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/host/autonat/pb/autonat.proto\x12\nautonat.pb\"\x81\x01\n\x07Message\x12\x1e\n\x04type\x18\x01 \x01(\x0e\x32\x10.autonat.pb.Type\x12%\n\x04\x64ial\x18\x02 \x01(\x0b\x32\x17.autonat.pb.DialRequest\x12/\n\rdial_response\x18\x03 \x01(\x0b\x32\x18.autonat.pb.DialResponse\"2\n\x0b\x44ialRequest\x12#\n\x05peers\x18\x01 \x03(\x0b\x32\x14.autonat.pb.PeerInfo\"W\n\x0c\x44ialResponse\x12\"\n\x06status\x18\x01 \x01(\x0e\x32\x12.autonat.pb.Status\x12#\n\x05peers\x18\x02 \x03(\x0b\x32\x14.autonat.pb.PeerInfo\"6\n\x08PeerInfo\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12\x0f\n\x07success\x18\x03 \x01(\x08*0\n\x04Type\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04\x44IAL\x10\x01\x12\x11\n\rDIAL_RESPONSE\x10\x02*_\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\x10\n\x0c\x45_DIAL_ERROR\x10\x01\x12\x12\n\x0e\x45_DIAL_REFUSED\x10\x02\x12\x11\n\rE_DIAL_FAILED\x10\x03\x12\x14\n\x10\x45_INTERNAL_ERROR\x10\x64\x32=\n\x07\x41utoNAT\x12\x32\n\x04\x44ial\x12\x13.autonat.pb.Message\x1a\x13.autonat.pb.Message\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.host.autonat.pb.autonat_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TYPE']._serialized_start=381 + _globals['_TYPE']._serialized_end=429 + _globals['_STATUS']._serialized_start=431 + _globals['_STATUS']._serialized_end=526 + _globals['_MESSAGE']._serialized_start=53 + _globals['_MESSAGE']._serialized_end=182 + _globals['_DIALREQUEST']._serialized_start=184 + _globals['_DIALREQUEST']._serialized_end=234 + _globals['_DIALRESPONSE']._serialized_start=236 + _globals['_DIALRESPONSE']._serialized_end=323 + _globals['_PEERINFO']._serialized_start=325 + _globals['_PEERINFO']._serialized_end=379 + _globals['_AUTONAT']._serialized_start=528 + _globals['_AUTONAT']._serialized_end=589 +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/host/autonat/pb/autonat_pb2.pyi b/libp2p/host/autonat/pb/autonat_pb2.pyi new file mode 100644 index 00000000..9053b735 --- /dev/null +++ b/libp2p/host/autonat/pb/autonat_pb2.pyi @@ -0,0 +1,56 @@ +from typing import Any, List, Optional, Union + +class Message: + type: int + dial: Any + dial_response: Any + + def ParseFromString(self, data: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + @staticmethod + def FromString(data: bytes) -> 'Message': ... + +class DialRequest: + peers: List[Any] + + def ParseFromString(self, data: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + @staticmethod + def FromString(data: bytes) -> 'DialRequest': ... + +class DialResponse: + status: int + peers: List[Any] + + def ParseFromString(self, data: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + @staticmethod + def FromString(data: bytes) -> 'DialResponse': ... + +class PeerInfo: + id: bytes + addrs: List[bytes] + success: bool + + def ParseFromString(self, data: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + @staticmethod + def FromString(data: bytes) -> 'PeerInfo': ... + +class Type: + UNKNOWN: int + DIAL: int + DIAL_RESPONSE: int + + @staticmethod + def Value(name: str) -> int: ... + +class Status: + OK: int + E_DIAL_ERROR: int + E_DIAL_REFUSED: int + E_DIAL_FAILED: int + E_INTERNAL_ERROR: int + + @staticmethod + def Value(name: str) -> int: ... \ No newline at end of file diff --git a/libp2p/host/autonat/pb/autonat_pb2_grpc.py b/libp2p/host/autonat/pb/autonat_pb2_grpc.py new file mode 100644 index 00000000..f43fe67f --- /dev/null +++ b/libp2p/host/autonat/pb/autonat_pb2_grpc.py @@ -0,0 +1,108 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +from typing import Any, Optional + +from . import autonat_pb2 as autonat__pb2 + +GRPC_GENERATED_VERSION = "1.71.0" +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + + _version_not_supported = first_version_is_lower( + GRPC_VERSION, GRPC_GENERATED_VERSION + ) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f"The grpc package installed is at version {GRPC_VERSION}," + + f" but the generated code in autonat_pb2_grpc.py depends on" + + f" grpcio>={GRPC_GENERATED_VERSION}." + + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + ) + + +class AutoNATStub: + """AutoNAT service definition""" + + def __init__(self, channel: grpc.Channel) -> None: + """Initialize the AutoNAT stub. + + Args: + ---- + channel (grpc.Channel): The gRPC channel instance that facilitates + communication for the AutoNAT service, providing the underlying + transport mechanism for RPC calls. + + """ + self.Dial = channel.unary_unary( + "/autonat.pb.AutoNAT/Dial", + request_serializer=autonat__pb2.Message.SerializeToString, + response_deserializer=autonat__pb2.Message.FromString, + _registered_method=True, + ) + + +class AutoNATServicer: + """AutoNAT service definition""" + + def Dial(self, request: Any, context: Any) -> Any: + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_AutoNATServicer_to_server(servicer: AutoNATServicer, server: Any) -> None: + rpc_method_handlers = { + "Dial": grpc.unary_unary_rpc_method_handler( + servicer.Dial, + request_deserializer=autonat__pb2.Message.FromString, + response_serializer=autonat__pb2.Message.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "autonat.pb.AutoNAT", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("autonat.pb.AutoNAT", rpc_method_handlers) + + +# This class is part of an EXPERIMENTAL API. +class AutoNAT: + """AutoNAT service definition""" + + @staticmethod + def Dial( + request: Any, + target: str, + options: tuple[Any, ...] = (), + channel_credentials: Optional[Any] = None, + call_credentials: Optional[Any] = None, + insecure: bool = False, + compression: Optional[Any] = None, + wait_for_ready: Optional[bool] = None, + timeout: Optional[float] = None, + metadata: Optional[list[tuple[str, str]]] = None, + ) -> Any: + return grpc.experimental.unary_unary( + request, + target, + "/autonat.pb.AutoNAT/Dial", + autonat__pb2.Message.SerializeToString, + autonat__pb2.Message.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/libp2p/host/autonat/pb/autonat_pb2_grpc.pyi b/libp2p/host/autonat/pb/autonat_pb2_grpc.pyi new file mode 100644 index 00000000..48929fc1 --- /dev/null +++ b/libp2p/host/autonat/pb/autonat_pb2_grpc.pyi @@ -0,0 +1,26 @@ +from typing import Any, List, Optional, Tuple, Union +import grpc + +class AutoNATStub: + def __init__(self, channel: grpc.Channel) -> None: ... + Dial: Any + +class AutoNATServicer: + def Dial(self, request: Any, context: Any) -> Any: ... + +def add_AutoNATServicer_to_server(servicer: AutoNATServicer, server: Any) -> None: ... + +class AutoNAT: + @staticmethod + def Dial( + request: Any, + target: str, + options: Tuple[Any, ...] = (), + channel_credentials: Optional[Any] = None, + call_credentials: Optional[Any] = None, + insecure: bool = False, + compression: Optional[Any] = None, + wait_for_ready: Optional[bool] = None, + timeout: Optional[float] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> Any: ... diff --git a/libp2p/host/autonat/pb/generate_proto.py b/libp2p/host/autonat/pb/generate_proto.py new file mode 100755 index 00000000..a5424f4e --- /dev/null +++ b/libp2p/host/autonat/pb/generate_proto.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +import subprocess + + +def generate_proto() -> None: + proto_file = "autonat.proto" + output_dir = "." + + # Ensure protoc is installed + try: + subprocess.run(["protoc", "--version"], check=True, capture_output=True) + except subprocess.CalledProcessError: + print("Error: protoc is not installed. Please install protobuf compiler.") + return + except FileNotFoundError: + print("Error: protoc is not found in PATH. Please install protobuf compiler.") + return + + # Generate Python code + cmd = [ + "protoc", + "--python_out=" + output_dir, + "--grpc_python_out=" + output_dir, + "-I.", + proto_file, + ] + + try: + subprocess.run(cmd, check=True) + print("Successfully generated protobuf code for " + proto_file) + except subprocess.CalledProcessError as e: + print("Error generating protobuf code: " + str(e)) + + +if __name__ == "__main__": + generate_proto() diff --git a/tests/core/host/test_autonat.py b/tests/core/host/test_autonat.py new file mode 100644 index 00000000..fe394745 --- /dev/null +++ b/tests/core/host/test_autonat.py @@ -0,0 +1,240 @@ +from unittest.mock import ( + AsyncMock, + patch, +) + +import pytest + +from libp2p.host.autonat.autonat import ( + AUTONAT_PROTOCOL_ID, + AutoNATService, + AutoNATStatus, +) +from libp2p.host.autonat.pb.autonat_pb2 import ( + DialRequest, + DialResponse, + Message, + PeerInfo, + Status, + Type, +) +from libp2p.network.stream.exceptions import ( + StreamError, +) +from libp2p.network.stream.net_stream import ( + NetStream, +) +from libp2p.peer.id import ( + ID, +) +from tests.utils.factories import ( + HostFactory, +) + + +@pytest.mark.trio +async def test_autonat_service_initialization(): + """Test that the AutoNAT service initializes correctly.""" + async with HostFactory.create_batch_and_listen(1) as hosts: + host = hosts[0] + service = AutoNATService(host) + + assert service.status == AutoNATStatus.UNKNOWN + assert service.dial_results == {} + assert service.host == host + assert service.peerstore == host.get_peerstore() + + +@pytest.mark.trio +async def test_autonat_status_getter(): + """Test that the AutoNAT status getter works correctly.""" + async with HostFactory.create_batch_and_listen(1) as hosts: + host = hosts[0] + service = AutoNATService(host) + + # Testing the initial status + assert service.get_status() == AutoNATStatus.UNKNOWN + + # Testing the status changes + service.status = AutoNATStatus.PUBLIC + assert service.get_status() == AutoNATStatus.PUBLIC + + service.status = AutoNATStatus.PRIVATE + assert service.get_status() == AutoNATStatus.PRIVATE + + +@pytest.mark.trio +async def test_update_status(): + """Test that the AutoNAT status updates correctly based on dial results.""" + async with HostFactory.create_batch_and_listen(1) as hosts: + host = hosts[0] + service = AutoNATService(host) + + # No dial results should result in UNKNOWN status + service.update_status() + assert service.status == AutoNATStatus.UNKNOWN + + # Less than 2 successful dials should result in PRIVATE status + service.dial_results = { + ID("peer1"): True, + ID("peer2"): False, + ID("peer3"): False, + } + service.update_status() + assert service.status == AutoNATStatus.PRIVATE + + # 2 or more successful dials should result in PUBLIC status + service.dial_results = { + ID("peer1"): True, + ID("peer2"): True, + ID("peer3"): False, + } + service.update_status() + assert service.status == AutoNATStatus.PUBLIC + + +@pytest.mark.trio +async def test_try_dial(): + """Test that the try_dial method works correctly.""" + async with HostFactory.create_batch_and_listen(2) as hosts: + host1, host2 = hosts + service = AutoNATService(host1) + peer_id = host2.get_id() + + # Test successful dial + with patch.object( + host1, "new_stream", new_callable=AsyncMock + ) as mock_new_stream: + mock_stream = AsyncMock(spec=NetStream) + mock_new_stream.return_value = mock_stream + + result = await service._try_dial(peer_id) + + assert result is True + mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID]) + mock_stream.close.assert_called_once() + + # Test failed dial + with patch.object( + host1, "new_stream", new_callable=AsyncMock + ) as mock_new_stream: + mock_new_stream.side_effect = Exception("Connection failed") + + result = await service._try_dial(peer_id) + + assert result is False + mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID]) + + +@pytest.mark.trio +async def test_handle_dial(): + """Test that the handle_dial method works correctly.""" + async with HostFactory.create_batch_and_listen(2) as hosts: + host1, host2 = hosts + service = AutoNATService(host1) + peer_id = host2.get_id() + + # Create a mock message with a peer to dial + message = Message() + message.type = Type.Value("DIAL") + peer_info = PeerInfo() + peer_info.id = peer_id.to_bytes() + peer_info.addrs.extend([b"/ip4/127.0.0.1/tcp/4001"]) + message.dial.peers.append(peer_info) + + # Mock the _try_dial method + with patch.object( + service, "_try_dial", new_callable=AsyncMock + ) as mock_try_dial: + mock_try_dial.return_value = True + + response = await service._handle_dial(message) + + assert response.type == Type.Value("DIAL_RESPONSE") + assert response.dial_response.status == Status.OK + assert len(response.dial_response.peers) == 1 + assert response.dial_response.peers[0].id == peer_id.to_bytes() + assert response.dial_response.peers[0].success is True + mock_try_dial.assert_called_once_with(peer_id) + + +@pytest.mark.trio +async def test_handle_request(): + """Test that the handle_request method works correctly.""" + async with HostFactory.create_batch_and_listen(1) as hosts: + host = hosts[0] + service = AutoNATService(host) + + # Test handling a DIAL request + message = Message() + message.type = Type.DIAL + dial_request = DialRequest() + peer_info = PeerInfo() + dial_request.peers.append(peer_info) + message.dial.CopyFrom(dial_request) + + with patch.object( + service, "_handle_dial", new_callable=AsyncMock + ) as mock_handle_dial: + mock_handle_dial.return_value = Message() + + response = await service._handle_request(message.SerializeToString()) + + mock_handle_dial.assert_called_once() + assert isinstance(response, Message) + + # Test handling an unknown request type + message = Message() + message.type = Type.UNKNOWN + + response = await service._handle_request(message.SerializeToString()) + + assert isinstance(response, Message) + assert response.type == Type.DIAL_RESPONSE + assert response.dial_response.status == Status.E_INTERNAL_ERROR + + +@pytest.mark.trio +async def test_handle_stream(): + """Test that handle_stream correctly processes stream data.""" + async with HostFactory.create_batch_and_listen(1) as hosts: + host = hosts[0] + autonat_service = AutoNATService(host) + + # Create a mock stream + mock_stream = AsyncMock(spec=NetStream) + + # Create a properly initialized request Message + request = Message() + request.type = Type.DIAL + dial_request = DialRequest() + peer_info = PeerInfo() + peer_info.id = b"peer_id" + peer_info.addrs.append(b"addr1") + dial_request.peers.append(peer_info) + request.dial.CopyFrom(dial_request) + + # Create a properly initialized response Message + response = Message() + response.type = Type.DIAL_RESPONSE + dial_response = DialResponse() + dial_response.status = Status.OK + dial_response.peers.append(peer_info) + response.dial_response.CopyFrom(dial_response) + + # Mock stream read/write and _handle_request + mock_stream.read.return_value = request.SerializeToString() + mock_stream.write.return_value = None + autonat_service._handle_request = AsyncMock(return_value=response) + + # Test successful stream handling + await autonat_service.handle_stream(mock_stream) + mock_stream.read.assert_called_once() + mock_stream.write.assert_called_once_with(response.SerializeToString()) + mock_stream.close.assert_called_once() + + # Test stream error handling + mock_stream.reset_mock() + mock_stream.read.side_effect = StreamError("Stream error") + await autonat_service.handle_stream(mock_stream) + mock_stream.close.assert_called_once()