mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
feat/561-added autonat service
This commit is contained in:
committed by
Paul Robinson
parent
fd893afba6
commit
9655c88788
5
libp2p/host/autonat/__init__.py
Normal file
5
libp2p/host/autonat/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .autonat import (
|
||||
AutoNATService,
|
||||
)
|
||||
|
||||
__all__ = ["AutoNATService"]
|
||||
212
libp2p/host/autonat/autonat.py
Normal file
212
libp2p/host/autonat/autonat.py
Normal file
@ -0,0 +1,212 @@
|
||||
import logging
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.host.autonat.pb.autonat_pb2 import (
|
||||
DialResponse,
|
||||
Message,
|
||||
PeerInfo,
|
||||
Status,
|
||||
Type,
|
||||
)
|
||||
from libp2p.host.basic_host import (
|
||||
BasicHost,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
NetStream,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerstore import (
|
||||
IPeerStore,
|
||||
)
|
||||
|
||||
AUTONAT_PROTOCOL_ID = TProtocol("/ipfs/autonat/1.0.0")
|
||||
|
||||
logger = logging.getLogger("libp2p.host.autonat")
|
||||
|
||||
|
||||
class AutoNATStatus:
|
||||
"""
|
||||
Status enumeration for the AutoNAT service.
|
||||
|
||||
This class defines the possible states of NAT traversal for a node:
|
||||
- UNKNOWN (0): The node's NAT status has not been determined yet
|
||||
- PUBLIC (1): The node is publicly reachable
|
||||
- PRIVATE (2): The node is behind NAT and not directly reachable
|
||||
"""
|
||||
|
||||
UNKNOWN = 0
|
||||
PUBLIC = 1
|
||||
PRIVATE = 2
|
||||
|
||||
|
||||
class AutoNATService:
|
||||
"""
|
||||
Service for determining if a node is publicly reachable.
|
||||
|
||||
The AutoNAT service helps nodes determine their NAT status by attempting
|
||||
to establish connections with other peers. It maintains a record of dial
|
||||
attempts and their results to classify the node as public or private.
|
||||
"""
|
||||
|
||||
def __init__(self, host: BasicHost) -> None:
|
||||
"""
|
||||
Initialize the AutoNAT service.
|
||||
|
||||
Args:
|
||||
host (BasicHost): The libp2p host instance that provides networking
|
||||
capabilities for the AutoNAT service, including peer discovery
|
||||
and connection management.
|
||||
"""
|
||||
self.host = host
|
||||
self.peerstore: IPeerStore = host.get_peerstore()
|
||||
self.status = AutoNATStatus.UNKNOWN
|
||||
self.dial_results: dict[ID, bool] = {}
|
||||
|
||||
async def handle_stream(self, stream: NetStream) -> None:
|
||||
"""
|
||||
Handle an incoming stream.
|
||||
|
||||
:param stream: The stream to handle
|
||||
"""
|
||||
try:
|
||||
request_bytes = await stream.read()
|
||||
request = Message()
|
||||
request.ParseFromString(request_bytes)
|
||||
response = await self._handle_request(request)
|
||||
await stream.write(response.SerializeToString())
|
||||
except Exception as e:
|
||||
logger.error("Error handling AutoNAT stream: %s", str(e))
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
async def _handle_request(self, request: bytes | Message) -> Message:
|
||||
"""
|
||||
Handle an AutoNAT request.
|
||||
|
||||
Parses and processes incoming AutoNAT requests, routing them to
|
||||
appropriate handlers based on the message type.
|
||||
|
||||
Args:
|
||||
request (bytes | Message): The request bytes that need to be parsed
|
||||
and handled by the AutoNAT service, or a Message object directly.
|
||||
|
||||
Returns:
|
||||
Message: The response message containing the result of the request.
|
||||
Returns an error response if the request type is not recognized.
|
||||
"""
|
||||
if isinstance(request, bytes):
|
||||
message = Message()
|
||||
message.ParseFromString(request)
|
||||
else:
|
||||
message = request
|
||||
|
||||
if message.type == Type.Value("DIAL"):
|
||||
response = await self._handle_dial(message)
|
||||
return response
|
||||
|
||||
# Handle unknown request type
|
||||
response = Message()
|
||||
response.type = Type.Value("DIAL_RESPONSE")
|
||||
error_response = DialResponse()
|
||||
error_response.status = Status.E_INTERNAL_ERROR
|
||||
response.dial_response.CopyFrom(error_response)
|
||||
return response
|
||||
|
||||
async def _handle_dial(self, message: Message) -> Message:
|
||||
"""
|
||||
Handle a DIAL request.
|
||||
|
||||
Processes dial requests by attempting to connect to specified peers
|
||||
and recording the results of these connection attempts.
|
||||
|
||||
Args:
|
||||
message (Message): The request message containing the dial request
|
||||
parameters and peer information.
|
||||
|
||||
Returns:
|
||||
Message: The response message containing the dial results, including
|
||||
success/failure status for each attempted peer connection.
|
||||
"""
|
||||
response = Message()
|
||||
response.type = Type.Value("DIAL_RESPONSE")
|
||||
dial_response = DialResponse()
|
||||
dial_response.status = Status.OK
|
||||
|
||||
for peer in message.dial.peers:
|
||||
peer_id = ID(peer.id)
|
||||
if peer_id in self.dial_results:
|
||||
success = self.dial_results[peer_id]
|
||||
else:
|
||||
success = await self._try_dial(peer_id)
|
||||
self.dial_results[peer_id] = success
|
||||
|
||||
peer_info = PeerInfo()
|
||||
peer_info.id = peer_id.to_bytes()
|
||||
peer_info.addrs.extend(peer.addrs)
|
||||
peer_info.success = success
|
||||
dial_response.peers.append(peer_info)
|
||||
|
||||
# Initialize the dial_response field if it doesn't exist
|
||||
if not hasattr(response, "dial_response"):
|
||||
response.dial_response = DialResponse()
|
||||
response.dial_response.CopyFrom(dial_response)
|
||||
return response
|
||||
|
||||
async def _try_dial(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Try to dial a peer.
|
||||
|
||||
Attempts to establish a connection with a specified peer to test
|
||||
NAT traversal capabilities.
|
||||
|
||||
Args:
|
||||
peer_id (ID): The identifier of the peer to attempt to dial for
|
||||
NAT traversal testing.
|
||||
|
||||
Returns:
|
||||
bool: True if the dial was successful and a connection could be
|
||||
established, False if the connection attempt failed.
|
||||
"""
|
||||
try:
|
||||
stream = await self.host.new_stream(peer_id, [AUTONAT_PROTOCOL_ID])
|
||||
await stream.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_status(self) -> int:
|
||||
"""
|
||||
Get the current AutoNAT status.
|
||||
|
||||
Retrieves the current NAT status of the node based on previous
|
||||
dial attempts and their results.
|
||||
|
||||
Returns:
|
||||
int: The current status as an integer:
|
||||
- AutoNATStatus.UNKNOWN (0): Status not yet determined
|
||||
- AutoNATStatus.PUBLIC (1): Node is publicly reachable
|
||||
- AutoNATStatus.PRIVATE (2): Node is behind NAT
|
||||
"""
|
||||
return self.status
|
||||
|
||||
def update_status(self) -> None:
|
||||
"""
|
||||
Update the AutoNAT status based on dial results.
|
||||
|
||||
Analyzes the results of previous dial attempts to determine if the
|
||||
node is publicly reachable. The node is considered public if at
|
||||
least two successful dial attempts have been recorded.
|
||||
"""
|
||||
if not self.dial_results:
|
||||
self.status = AutoNATStatus.UNKNOWN
|
||||
return
|
||||
|
||||
success_count = sum(1 for success in self.dial_results.values() if success)
|
||||
if success_count >= 2:
|
||||
self.status = AutoNATStatus.PUBLIC
|
||||
else:
|
||||
self.status = AutoNATStatus.PRIVATE
|
||||
0
libp2p/host/autonat/pb/__init__.py
Normal file
0
libp2p/host/autonat/pb/__init__.py
Normal file
49
libp2p/host/autonat/pb/autonat.proto
Normal file
49
libp2p/host/autonat/pb/autonat.proto
Normal file
@ -0,0 +1,49 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package autonat.pb;
|
||||
|
||||
// AutoNAT service definition
|
||||
service AutoNAT {
|
||||
rpc Dial (Message) returns (Message) {}
|
||||
}
|
||||
|
||||
// Message types
|
||||
enum Type {
|
||||
UNKNOWN = 0;
|
||||
DIAL = 1;
|
||||
DIAL_RESPONSE = 2;
|
||||
}
|
||||
|
||||
// Status codes
|
||||
enum Status {
|
||||
OK = 0;
|
||||
E_DIAL_ERROR = 1;
|
||||
E_DIAL_REFUSED = 2;
|
||||
E_DIAL_FAILED = 3;
|
||||
E_INTERNAL_ERROR = 100;
|
||||
}
|
||||
|
||||
// Main message
|
||||
message Message {
|
||||
Type type = 1;
|
||||
DialRequest dial = 2;
|
||||
DialResponse dial_response = 3;
|
||||
}
|
||||
|
||||
// Dial request
|
||||
message DialRequest {
|
||||
repeated PeerInfo peers = 1;
|
||||
}
|
||||
|
||||
// Dial response
|
||||
message DialResponse {
|
||||
Status status = 1;
|
||||
repeated PeerInfo peers = 2;
|
||||
}
|
||||
|
||||
// Peer information
|
||||
message PeerInfo {
|
||||
bytes id = 1;
|
||||
repeated bytes addrs = 2;
|
||||
bool success = 3;
|
||||
}
|
||||
48
libp2p/host/autonat/pb/autonat_pb2.py
Normal file
48
libp2p/host/autonat/pb/autonat_pb2.py
Normal file
@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: libp2p/host/autonat/pb/autonat.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'libp2p/host/autonat/pb/autonat.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/host/autonat/pb/autonat.proto\x12\nautonat.pb\"\x81\x01\n\x07Message\x12\x1e\n\x04type\x18\x01 \x01(\x0e\x32\x10.autonat.pb.Type\x12%\n\x04\x64ial\x18\x02 \x01(\x0b\x32\x17.autonat.pb.DialRequest\x12/\n\rdial_response\x18\x03 \x01(\x0b\x32\x18.autonat.pb.DialResponse\"2\n\x0b\x44ialRequest\x12#\n\x05peers\x18\x01 \x03(\x0b\x32\x14.autonat.pb.PeerInfo\"W\n\x0c\x44ialResponse\x12\"\n\x06status\x18\x01 \x01(\x0e\x32\x12.autonat.pb.Status\x12#\n\x05peers\x18\x02 \x03(\x0b\x32\x14.autonat.pb.PeerInfo\"6\n\x08PeerInfo\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12\x0f\n\x07success\x18\x03 \x01(\x08*0\n\x04Type\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04\x44IAL\x10\x01\x12\x11\n\rDIAL_RESPONSE\x10\x02*_\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\x10\n\x0c\x45_DIAL_ERROR\x10\x01\x12\x12\n\x0e\x45_DIAL_REFUSED\x10\x02\x12\x11\n\rE_DIAL_FAILED\x10\x03\x12\x14\n\x10\x45_INTERNAL_ERROR\x10\x64\x32=\n\x07\x41utoNAT\x12\x32\n\x04\x44ial\x12\x13.autonat.pb.Message\x1a\x13.autonat.pb.Message\"\x00\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.host.autonat.pb.autonat_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TYPE']._serialized_start=381
|
||||
_globals['_TYPE']._serialized_end=429
|
||||
_globals['_STATUS']._serialized_start=431
|
||||
_globals['_STATUS']._serialized_end=526
|
||||
_globals['_MESSAGE']._serialized_start=53
|
||||
_globals['_MESSAGE']._serialized_end=182
|
||||
_globals['_DIALREQUEST']._serialized_start=184
|
||||
_globals['_DIALREQUEST']._serialized_end=234
|
||||
_globals['_DIALRESPONSE']._serialized_start=236
|
||||
_globals['_DIALRESPONSE']._serialized_end=323
|
||||
_globals['_PEERINFO']._serialized_start=325
|
||||
_globals['_PEERINFO']._serialized_end=379
|
||||
_globals['_AUTONAT']._serialized_start=528
|
||||
_globals['_AUTONAT']._serialized_end=589
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
56
libp2p/host/autonat/pb/autonat_pb2.pyi
Normal file
56
libp2p/host/autonat/pb/autonat_pb2.pyi
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
class Message:
|
||||
type: int
|
||||
dial: Any
|
||||
dial_response: Any
|
||||
|
||||
def ParseFromString(self, data: bytes) -> None: ...
|
||||
def SerializeToString(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def FromString(data: bytes) -> 'Message': ...
|
||||
|
||||
class DialRequest:
|
||||
peers: List[Any]
|
||||
|
||||
def ParseFromString(self, data: bytes) -> None: ...
|
||||
def SerializeToString(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def FromString(data: bytes) -> 'DialRequest': ...
|
||||
|
||||
class DialResponse:
|
||||
status: int
|
||||
peers: List[Any]
|
||||
|
||||
def ParseFromString(self, data: bytes) -> None: ...
|
||||
def SerializeToString(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def FromString(data: bytes) -> 'DialResponse': ...
|
||||
|
||||
class PeerInfo:
|
||||
id: bytes
|
||||
addrs: List[bytes]
|
||||
success: bool
|
||||
|
||||
def ParseFromString(self, data: bytes) -> None: ...
|
||||
def SerializeToString(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def FromString(data: bytes) -> 'PeerInfo': ...
|
||||
|
||||
class Type:
|
||||
UNKNOWN: int
|
||||
DIAL: int
|
||||
DIAL_RESPONSE: int
|
||||
|
||||
@staticmethod
|
||||
def Value(name: str) -> int: ...
|
||||
|
||||
class Status:
|
||||
OK: int
|
||||
E_DIAL_ERROR: int
|
||||
E_DIAL_REFUSED: int
|
||||
E_DIAL_FAILED: int
|
||||
E_INTERNAL_ERROR: int
|
||||
|
||||
@staticmethod
|
||||
def Value(name: str) -> int: ...
|
||||
108
libp2p/host/autonat/pb/autonat_pb2_grpc.py
Normal file
108
libp2p/host/autonat/pb/autonat_pb2_grpc.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
from typing import Any, Optional
|
||||
|
||||
from . import autonat_pb2 as autonat__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = "1.71.0"
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
|
||||
_version_not_supported = first_version_is_lower(
|
||||
GRPC_VERSION, GRPC_GENERATED_VERSION
|
||||
)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f"The grpc package installed is at version {GRPC_VERSION},"
|
||||
+ f" but the generated code in autonat_pb2_grpc.py depends on"
|
||||
+ f" grpcio>={GRPC_GENERATED_VERSION}."
|
||||
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
|
||||
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
|
||||
)
|
||||
|
||||
|
||||
class AutoNATStub:
|
||||
"""AutoNAT service definition"""
|
||||
|
||||
def __init__(self, channel: grpc.Channel) -> None:
|
||||
"""Initialize the AutoNAT stub.
|
||||
|
||||
Args:
|
||||
----
|
||||
channel (grpc.Channel): The gRPC channel instance that facilitates
|
||||
communication for the AutoNAT service, providing the underlying
|
||||
transport mechanism for RPC calls.
|
||||
|
||||
"""
|
||||
self.Dial = channel.unary_unary(
|
||||
"/autonat.pb.AutoNAT/Dial",
|
||||
request_serializer=autonat__pb2.Message.SerializeToString,
|
||||
response_deserializer=autonat__pb2.Message.FromString,
|
||||
_registered_method=True,
|
||||
)
|
||||
|
||||
|
||||
class AutoNATServicer:
|
||||
"""AutoNAT service definition"""
|
||||
|
||||
def Dial(self, request: Any, context: Any) -> Any:
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Method not implemented!")
|
||||
raise NotImplementedError("Method not implemented!")
|
||||
|
||||
|
||||
def add_AutoNATServicer_to_server(servicer: AutoNATServicer, server: Any) -> None:
|
||||
rpc_method_handlers = {
|
||||
"Dial": grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Dial,
|
||||
request_deserializer=autonat__pb2.Message.FromString,
|
||||
response_serializer=autonat__pb2.Message.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
"autonat.pb.AutoNAT", rpc_method_handlers
|
||||
)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers("autonat.pb.AutoNAT", rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AutoNAT:
|
||||
"""AutoNAT service definition"""
|
||||
|
||||
@staticmethod
|
||||
def Dial(
|
||||
request: Any,
|
||||
target: str,
|
||||
options: tuple[Any, ...] = (),
|
||||
channel_credentials: Optional[Any] = None,
|
||||
call_credentials: Optional[Any] = None,
|
||||
insecure: bool = False,
|
||||
compression: Optional[Any] = None,
|
||||
wait_for_ready: Optional[bool] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[list[tuple[str, str]]] = None,
|
||||
) -> Any:
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
"/autonat.pb.AutoNAT/Dial",
|
||||
autonat__pb2.Message.SerializeToString,
|
||||
autonat__pb2.Message.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
)
|
||||
26
libp2p/host/autonat/pb/autonat_pb2_grpc.pyi
Normal file
26
libp2p/host/autonat/pb/autonat_pb2_grpc.pyi
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import grpc
|
||||
|
||||
class AutoNATStub:
|
||||
def __init__(self, channel: grpc.Channel) -> None: ...
|
||||
Dial: Any
|
||||
|
||||
class AutoNATServicer:
|
||||
def Dial(self, request: Any, context: Any) -> Any: ...
|
||||
|
||||
def add_AutoNATServicer_to_server(servicer: AutoNATServicer, server: Any) -> None: ...
|
||||
|
||||
class AutoNAT:
|
||||
@staticmethod
|
||||
def Dial(
|
||||
request: Any,
|
||||
target: str,
|
||||
options: Tuple[Any, ...] = (),
|
||||
channel_credentials: Optional[Any] = None,
|
||||
call_credentials: Optional[Any] = None,
|
||||
insecure: bool = False,
|
||||
compression: Optional[Any] = None,
|
||||
wait_for_ready: Optional[bool] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[List[Tuple[str, str]]] = None,
|
||||
) -> Any: ...
|
||||
36
libp2p/host/autonat/pb/generate_proto.py
Executable file
36
libp2p/host/autonat/pb/generate_proto.py
Executable file
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
|
||||
|
||||
def generate_proto() -> None:
|
||||
proto_file = "autonat.proto"
|
||||
output_dir = "."
|
||||
|
||||
# Ensure protoc is installed
|
||||
try:
|
||||
subprocess.run(["protoc", "--version"], check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print("Error: protoc is not installed. Please install protobuf compiler.")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
print("Error: protoc is not found in PATH. Please install protobuf compiler.")
|
||||
return
|
||||
|
||||
# Generate Python code
|
||||
cmd = [
|
||||
"protoc",
|
||||
"--python_out=" + output_dir,
|
||||
"--grpc_python_out=" + output_dir,
|
||||
"-I.",
|
||||
proto_file,
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print("Successfully generated protobuf code for " + proto_file)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Error generating protobuf code: " + str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_proto()
|
||||
240
tests/core/host/test_autonat.py
Normal file
240
tests/core/host/test_autonat.py
Normal file
@ -0,0 +1,240 @@
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.host.autonat.autonat import (
|
||||
AUTONAT_PROTOCOL_ID,
|
||||
AutoNATService,
|
||||
AutoNATStatus,
|
||||
)
|
||||
from libp2p.host.autonat.pb.autonat_pb2 import (
|
||||
DialRequest,
|
||||
DialResponse,
|
||||
Message,
|
||||
PeerInfo,
|
||||
Status,
|
||||
Type,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamError,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
NetStream,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
HostFactory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_autonat_service_initialization():
|
||||
"""Test that the AutoNAT service initializes correctly."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
service = AutoNATService(host)
|
||||
|
||||
assert service.status == AutoNATStatus.UNKNOWN
|
||||
assert service.dial_results == {}
|
||||
assert service.host == host
|
||||
assert service.peerstore == host.get_peerstore()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_autonat_status_getter():
|
||||
"""Test that the AutoNAT status getter works correctly."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
service = AutoNATService(host)
|
||||
|
||||
# Testing the initial status
|
||||
assert service.get_status() == AutoNATStatus.UNKNOWN
|
||||
|
||||
# Testing the status changes
|
||||
service.status = AutoNATStatus.PUBLIC
|
||||
assert service.get_status() == AutoNATStatus.PUBLIC
|
||||
|
||||
service.status = AutoNATStatus.PRIVATE
|
||||
assert service.get_status() == AutoNATStatus.PRIVATE
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_update_status():
|
||||
"""Test that the AutoNAT status updates correctly based on dial results."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
service = AutoNATService(host)
|
||||
|
||||
# No dial results should result in UNKNOWN status
|
||||
service.update_status()
|
||||
assert service.status == AutoNATStatus.UNKNOWN
|
||||
|
||||
# Less than 2 successful dials should result in PRIVATE status
|
||||
service.dial_results = {
|
||||
ID("peer1"): True,
|
||||
ID("peer2"): False,
|
||||
ID("peer3"): False,
|
||||
}
|
||||
service.update_status()
|
||||
assert service.status == AutoNATStatus.PRIVATE
|
||||
|
||||
# 2 or more successful dials should result in PUBLIC status
|
||||
service.dial_results = {
|
||||
ID("peer1"): True,
|
||||
ID("peer2"): True,
|
||||
ID("peer3"): False,
|
||||
}
|
||||
service.update_status()
|
||||
assert service.status == AutoNATStatus.PUBLIC
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_try_dial():
|
||||
"""Test that the try_dial method works correctly."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
host1, host2 = hosts
|
||||
service = AutoNATService(host1)
|
||||
peer_id = host2.get_id()
|
||||
|
||||
# Test successful dial
|
||||
with patch.object(
|
||||
host1, "new_stream", new_callable=AsyncMock
|
||||
) as mock_new_stream:
|
||||
mock_stream = AsyncMock(spec=NetStream)
|
||||
mock_new_stream.return_value = mock_stream
|
||||
|
||||
result = await service._try_dial(peer_id)
|
||||
|
||||
assert result is True
|
||||
mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID])
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
# Test failed dial
|
||||
with patch.object(
|
||||
host1, "new_stream", new_callable=AsyncMock
|
||||
) as mock_new_stream:
|
||||
mock_new_stream.side_effect = Exception("Connection failed")
|
||||
|
||||
result = await service._try_dial(peer_id)
|
||||
|
||||
assert result is False
|
||||
mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID])
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_dial():
|
||||
"""Test that the handle_dial method works correctly."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
host1, host2 = hosts
|
||||
service = AutoNATService(host1)
|
||||
peer_id = host2.get_id()
|
||||
|
||||
# Create a mock message with a peer to dial
|
||||
message = Message()
|
||||
message.type = Type.Value("DIAL")
|
||||
peer_info = PeerInfo()
|
||||
peer_info.id = peer_id.to_bytes()
|
||||
peer_info.addrs.extend([b"/ip4/127.0.0.1/tcp/4001"])
|
||||
message.dial.peers.append(peer_info)
|
||||
|
||||
# Mock the _try_dial method
|
||||
with patch.object(
|
||||
service, "_try_dial", new_callable=AsyncMock
|
||||
) as mock_try_dial:
|
||||
mock_try_dial.return_value = True
|
||||
|
||||
response = await service._handle_dial(message)
|
||||
|
||||
assert response.type == Type.Value("DIAL_RESPONSE")
|
||||
assert response.dial_response.status == Status.OK
|
||||
assert len(response.dial_response.peers) == 1
|
||||
assert response.dial_response.peers[0].id == peer_id.to_bytes()
|
||||
assert response.dial_response.peers[0].success is True
|
||||
mock_try_dial.assert_called_once_with(peer_id)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_request():
|
||||
"""Test that the handle_request method works correctly."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
service = AutoNATService(host)
|
||||
|
||||
# Test handling a DIAL request
|
||||
message = Message()
|
||||
message.type = Type.DIAL
|
||||
dial_request = DialRequest()
|
||||
peer_info = PeerInfo()
|
||||
dial_request.peers.append(peer_info)
|
||||
message.dial.CopyFrom(dial_request)
|
||||
|
||||
with patch.object(
|
||||
service, "_handle_dial", new_callable=AsyncMock
|
||||
) as mock_handle_dial:
|
||||
mock_handle_dial.return_value = Message()
|
||||
|
||||
response = await service._handle_request(message.SerializeToString())
|
||||
|
||||
mock_handle_dial.assert_called_once()
|
||||
assert isinstance(response, Message)
|
||||
|
||||
# Test handling an unknown request type
|
||||
message = Message()
|
||||
message.type = Type.UNKNOWN
|
||||
|
||||
response = await service._handle_request(message.SerializeToString())
|
||||
|
||||
assert isinstance(response, Message)
|
||||
assert response.type == Type.DIAL_RESPONSE
|
||||
assert response.dial_response.status == Status.E_INTERNAL_ERROR
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_stream():
|
||||
"""Test that handle_stream correctly processes stream data."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
autonat_service = AutoNATService(host)
|
||||
|
||||
# Create a mock stream
|
||||
mock_stream = AsyncMock(spec=NetStream)
|
||||
|
||||
# Create a properly initialized request Message
|
||||
request = Message()
|
||||
request.type = Type.DIAL
|
||||
dial_request = DialRequest()
|
||||
peer_info = PeerInfo()
|
||||
peer_info.id = b"peer_id"
|
||||
peer_info.addrs.append(b"addr1")
|
||||
dial_request.peers.append(peer_info)
|
||||
request.dial.CopyFrom(dial_request)
|
||||
|
||||
# Create a properly initialized response Message
|
||||
response = Message()
|
||||
response.type = Type.DIAL_RESPONSE
|
||||
dial_response = DialResponse()
|
||||
dial_response.status = Status.OK
|
||||
dial_response.peers.append(peer_info)
|
||||
response.dial_response.CopyFrom(dial_response)
|
||||
|
||||
# Mock stream read/write and _handle_request
|
||||
mock_stream.read.return_value = request.SerializeToString()
|
||||
mock_stream.write.return_value = None
|
||||
autonat_service._handle_request = AsyncMock(return_value=response)
|
||||
|
||||
# Test successful stream handling
|
||||
await autonat_service.handle_stream(mock_stream)
|
||||
mock_stream.read.assert_called_once()
|
||||
mock_stream.write.assert_called_once_with(response.SerializeToString())
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
# Test stream error handling
|
||||
mock_stream.reset_mock()
|
||||
mock_stream.read.side_effect = StreamError("Stream error")
|
||||
await autonat_service.handle_stream(mock_stream)
|
||||
mock_stream.close.assert_called_once()
|
||||
Reference in New Issue
Block a user