mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-04-28 16:28:01 +00:00
Compare commits
15 Commits
9ed44f5fa3
...
varun-r-ma
| Author | SHA1 | Date | |
|---|---|---|---|
| af61523c87 | |||
| d2fdf70692 | |||
| 1ea50a3cf3 | |||
| f4247faa51 | |||
| 92e79bbb3f | |||
| eb3121b818 | |||
| 787648177f | |||
| fc9b28910a | |||
| 26d0ed2d81 | |||
| 618aff9368 | |||
| 32e545d9c7 | |||
| e712e6c0c4 | |||
| b143c96abd | |||
| 678b920992 | |||
| cb11f076c8 |
8
Makefile
8
Makefile
@ -60,6 +60,7 @@ PB = libp2p/crypto/pb/crypto.proto \
|
||||
libp2p/identity/identify/pb/identify.proto \
|
||||
libp2p/host/autonat/pb/autonat.proto \
|
||||
libp2p/relay/circuit_v2/pb/circuit.proto \
|
||||
libp2p/relay/circuit_v2/pb/dcutr.proto \
|
||||
libp2p/kad_dht/pb/kademlia.proto
|
||||
|
||||
PY = $(PB:.proto=_pb2.py)
|
||||
@ -68,6 +69,8 @@ PYI = $(PB:.proto=_pb2.pyi)
|
||||
## Set default to `protobufs`, otherwise `format` is called when typing only `make`
|
||||
all: protobufs
|
||||
|
||||
.PHONY: protobufs clean-proto
|
||||
|
||||
protobufs: $(PY)
|
||||
|
||||
%_pb2.py: %.proto
|
||||
@ -76,6 +79,11 @@ protobufs: $(PY)
|
||||
clean-proto:
|
||||
rm -f $(PY) $(PYI)
|
||||
|
||||
# Force protobuf regeneration by making them always out of date
|
||||
$(PY): FORCE
|
||||
|
||||
FORCE:
|
||||
|
||||
# docs commands
|
||||
|
||||
docs: check-docs
|
||||
|
||||
48
README.md
48
README.md
@ -34,19 +34,19 @@ ______________________________________________________________________
|
||||
| -------------------------------------- | :--------: | :---------------------------------------------------------------------------------: |
|
||||
| **`libp2p-tcp`** | ✅ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/transport/tcp/tcp.py) |
|
||||
| **`libp2p-quic`** | 🌱 | |
|
||||
| **`libp2p-websocket`** | ❌ | |
|
||||
| **`libp2p-webrtc-browser-to-server`** | ❌ | |
|
||||
| **`libp2p-webrtc-private-to-private`** | ❌ | |
|
||||
| **`libp2p-websocket`** | 🌱 | |
|
||||
| **`libp2p-webrtc-browser-to-server`** | 🌱 | |
|
||||
| **`libp2p-webrtc-private-to-private`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### NAT Traversal
|
||||
|
||||
| **NAT Traversal** | **Status** |
|
||||
| ----------------------------- | :--------: |
|
||||
| **`libp2p-circuit-relay-v2`** | ❌ |
|
||||
| **`libp2p-autonat`** | ❌ |
|
||||
| **`libp2p-hole-punching`** | ❌ |
|
||||
| **NAT Traversal** | **Status** | **Source** |
|
||||
| ----------------------------- | :--------: | :-----------------------------------------------------------------------------: |
|
||||
| **`libp2p-circuit-relay-v2`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
|
||||
| **`libp2p-autonat`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/host/autonat) |
|
||||
| **`libp2p-hole-punching`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/relay/circuit_v2) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -54,27 +54,27 @@ ______________________________________________________________________
|
||||
|
||||
| **Secure Communication** | **Status** | **Source** |
|
||||
| ------------------------ | :--------: | :---------------------------------------------------------------------------: |
|
||||
| **`libp2p-noise`** | 🌱 | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
|
||||
| **`libp2p-tls`** | ❌ | |
|
||||
| **`libp2p-noise`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/security/noise) |
|
||||
| **`libp2p-tls`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### Discovery
|
||||
|
||||
| **Discovery** | **Status** |
|
||||
| -------------------- | :--------: |
|
||||
| **`bootstrap`** | ❌ |
|
||||
| **`random-walk`** | ❌ |
|
||||
| **`mdns-discovery`** | ❌ |
|
||||
| **`rendezvous`** | ❌ |
|
||||
| **Discovery** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
|
||||
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
||||
| **`random-walk`** | 🌱 | |
|
||||
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
||||
| **`rendezvous`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### Peer Routing
|
||||
|
||||
| **Peer Routing** | **Status** |
|
||||
| -------------------- | :--------: |
|
||||
| **`libp2p-kad-dht`** | ❌ |
|
||||
| **Peer Routing** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :--------------------------------------------------------------------: |
|
||||
| **`libp2p-kad-dht`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/kad_dht) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -89,10 +89,10 @@ ______________________________________________________________________
|
||||
|
||||
### Stream Muxers
|
||||
|
||||
| **Stream Muxers** | **Status** | **Status** |
|
||||
| ------------------ | :--------: | :----------------------------------------------------------------------------------------: |
|
||||
| **`libp2p-yamux`** | 🌱 | |
|
||||
| **`libp2p-mplex`** | 🛠️ | [source](https://github.com/libp2p/py-libp2p/blob/main/libp2p/stream_muxer/mplex/mplex.py) |
|
||||
| **Stream Muxers** | **Status** | **Source** |
|
||||
| ------------------ | :--------: | :-------------------------------------------------------------------------------: |
|
||||
| **`libp2p-yamux`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/yamux) |
|
||||
| **`libp2p-mplex`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/stream_muxer/mplex) |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
@ -100,7 +100,7 @@ ______________________________________________________________________
|
||||
|
||||
| **Storage** | **Status** |
|
||||
| ------------------- | :--------: |
|
||||
| **`libp2p-record`** | ❌ |
|
||||
| **`libp2p-record`** | 🌱 |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
|
||||
@ -357,6 +357,14 @@ class INetConn(Closer):
|
||||
:return: A tuple containing instances of INetStream.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""
|
||||
Retrieve the transport addresses used by this connection.
|
||||
|
||||
:return: A list of multiaddresses used by the transport.
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------- peermetadata interface.py --------------------------
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*G\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/crypto/pb/crypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*S\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x12\x0c\n\x08\x45\x43\x43_P256\x10\x04\x12\n\n\x06X25519\x10\x05')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.crypto.pb.crypto_pb2', globals())
|
||||
@ -21,7 +21,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_KEYTYPE._serialized_start=175
|
||||
_KEYTYPE._serialized_end=246
|
||||
_KEYTYPE._serialized_end=258
|
||||
_PUBLICKEY._serialized_start=44
|
||||
_PUBLICKEY._serialized_end=107
|
||||
_PRIVATEKEY._serialized_start=109
|
||||
|
||||
@ -28,6 +28,7 @@ class _KeyTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTy
|
||||
Secp256k1: _KeyType.ValueType # 2
|
||||
ECDSA: _KeyType.ValueType # 3
|
||||
ECC_P256: _KeyType.ValueType # 4
|
||||
X25519: _KeyType.ValueType # 5
|
||||
|
||||
class KeyType(_KeyType, metaclass=_KeyTypeEnumTypeWrapper): ...
|
||||
|
||||
@ -36,6 +37,7 @@ Ed25519: KeyType.ValueType # 1
|
||||
Secp256k1: KeyType.ValueType # 2
|
||||
ECDSA: KeyType.ValueType # 3
|
||||
ECC_P256: KeyType.ValueType # 4
|
||||
X25519: KeyType.ValueType # 5
|
||||
global___KeyType = KeyType
|
||||
|
||||
@typing.final
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/identity/identify/pb/identify.proto
|
||||
# Protobuf Python Version: 4.25.3
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
@ -16,11 +15,11 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\xa9\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t\x12\x18\n\x10signedPeerRecord\x18\x08 \x01(\x0c')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', _globals)
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_globals['_IDENTIFY']._serialized_start=60
|
||||
_globals['_IDENTIFY']._serialized_end=229
|
||||
_IDENTIFY._serialized_start=60
|
||||
_IDENTIFY._serialized_end=229
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -1,24 +1,49 @@
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
class Identify(_message.Message):
|
||||
__slots__ = ("protocol_version", "agent_version", "public_key", "listen_addrs", "observed_addr", "protocols", "signedPeerRecord")
|
||||
PROTOCOL_VERSION_FIELD_NUMBER: _ClassVar[int]
|
||||
AGENT_VERSION_FIELD_NUMBER: _ClassVar[int]
|
||||
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
LISTEN_ADDRS_FIELD_NUMBER: _ClassVar[int]
|
||||
OBSERVED_ADDR_FIELD_NUMBER: _ClassVar[int]
|
||||
PROTOCOLS_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
protocol_version: str
|
||||
agent_version: str
|
||||
public_key: bytes
|
||||
listen_addrs: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
observed_addr: bytes
|
||||
protocols: _containers.RepeatedScalarFieldContainer[str]
|
||||
signedPeerRecord: bytes
|
||||
def __init__(self, protocol_version: _Optional[str] = ..., agent_version: _Optional[str] = ..., public_key: _Optional[bytes] = ..., listen_addrs: _Optional[_Iterable[bytes]] = ..., observed_addr: _Optional[bytes] = ..., protocols: _Optional[_Iterable[str]] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class Identify(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
PROTOCOL_VERSION_FIELD_NUMBER: builtins.int
|
||||
AGENT_VERSION_FIELD_NUMBER: builtins.int
|
||||
PUBLIC_KEY_FIELD_NUMBER: builtins.int
|
||||
LISTEN_ADDRS_FIELD_NUMBER: builtins.int
|
||||
OBSERVED_ADDR_FIELD_NUMBER: builtins.int
|
||||
PROTOCOLS_FIELD_NUMBER: builtins.int
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
|
||||
protocol_version: builtins.str
|
||||
agent_version: builtins.str
|
||||
public_key: builtins.bytes
|
||||
observed_addr: builtins.bytes
|
||||
signedPeerRecord: builtins.bytes
|
||||
@property
|
||||
def listen_addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
@property
|
||||
def protocols(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
protocol_version: builtins.str | None = ...,
|
||||
agent_version: builtins.str | None = ...,
|
||||
public_key: builtins.bytes | None = ...,
|
||||
listen_addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
observed_addr: builtins.bytes | None = ...,
|
||||
protocols: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
signedPeerRecord: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
|
||||
|
||||
global___Identify = Identify
|
||||
|
||||
@ -2,10 +2,10 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/kad_dht/pb/kademlia.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
@ -15,19 +15,19 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_globals['_RECORD']._serialized_start=36
|
||||
_globals['_RECORD']._serialized_end=94
|
||||
_globals['_MESSAGE']._serialized_start=97
|
||||
_globals['_MESSAGE']._serialized_end=555
|
||||
_globals['_MESSAGE_PEER']._serialized_start=281
|
||||
_globals['_MESSAGE_PEER']._serialized_end=359
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
|
||||
_RECORD._serialized_start=36
|
||||
_RECORD._serialized_end=94
|
||||
_MESSAGE._serialized_start=97
|
||||
_MESSAGE._serialized_end=555
|
||||
_MESSAGE_PEER._serialized_start=281
|
||||
_MESSAGE_PEER._serialized_end=359
|
||||
_MESSAGE_MESSAGETYPE._serialized_start=361
|
||||
_MESSAGE_MESSAGETYPE._serialized_end=466
|
||||
_MESSAGE_CONNECTIONTYPE._serialized_start=468
|
||||
_MESSAGE_CONNECTIONTYPE._serialized_end=555
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
@ -147,6 +148,24 @@ class SwarmConn(INetConn):
|
||||
def get_streams(self) -> tuple[NetStream, ...]:
|
||||
return tuple(self.streams)
|
||||
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""
|
||||
Retrieve the transport addresses used by this connection.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Multiaddr]
|
||||
A list of multiaddresses used by the transport.
|
||||
|
||||
"""
|
||||
# Return the addresses from the peerstore for this peer
|
||||
try:
|
||||
peer_id = self.muxed_conn.peer_id
|
||||
return self.swarm.peerstore.addrs(peer_id)
|
||||
except Exception as e:
|
||||
logging.warning(f"Error getting transport addresses: {e}")
|
||||
return []
|
||||
|
||||
def remove_stream(self, stream: NetStream) -> None:
|
||||
if stream not in self.streams:
|
||||
return
|
||||
|
||||
@ -39,8 +39,6 @@ from .peerinfo import (
|
||||
PERMANENT_ADDR_TTL = 0
|
||||
|
||||
|
||||
# TODO: Set up an async task for periodic peer-store cleanup
|
||||
# for expired addresses and records.
|
||||
class PeerRecordState:
|
||||
envelope: Envelope
|
||||
seq: int
|
||||
@ -217,7 +215,6 @@ class PeerStore(IPeerStore):
|
||||
|
||||
# -----CERT-ADDR-BOOK-----
|
||||
|
||||
# TODO: Make proper use of this function
|
||||
def maybe_delete_peer_record(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Delete the signed peer record for a peer if it has no know
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: rpc.proto
|
||||
# source: libp2p/pubsub/pb/rpc.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
@ -13,39 +13,39 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc_pb2', globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_RPC._serialized_start=25
|
||||
_RPC._serialized_end=205
|
||||
_RPC_SUBOPTS._serialized_start=160
|
||||
_RPC_SUBOPTS._serialized_end=205
|
||||
_MESSAGE._serialized_start=207
|
||||
_MESSAGE._serialized_end=312
|
||||
_CONTROLMESSAGE._serialized_start=315
|
||||
_CONTROLMESSAGE._serialized_end=491
|
||||
_CONTROLIHAVE._serialized_start=493
|
||||
_CONTROLIHAVE._serialized_end=544
|
||||
_CONTROLIWANT._serialized_start=546
|
||||
_CONTROLIWANT._serialized_end=580
|
||||
_CONTROLGRAFT._serialized_start=582
|
||||
_CONTROLGRAFT._serialized_end=613
|
||||
_CONTROLPRUNE._serialized_start=615
|
||||
_CONTROLPRUNE._serialized_end=699
|
||||
_PEERINFO._serialized_start=701
|
||||
_PEERINFO._serialized_end=753
|
||||
_TOPICDESCRIPTOR._serialized_start=756
|
||||
_TOPICDESCRIPTOR._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=889
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=975
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1013
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1016
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1147
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1104
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1147
|
||||
_RPC._serialized_start=42
|
||||
_RPC._serialized_end=222
|
||||
_RPC_SUBOPTS._serialized_start=177
|
||||
_RPC_SUBOPTS._serialized_end=222
|
||||
_MESSAGE._serialized_start=224
|
||||
_MESSAGE._serialized_end=329
|
||||
_CONTROLMESSAGE._serialized_start=332
|
||||
_CONTROLMESSAGE._serialized_end=508
|
||||
_CONTROLIHAVE._serialized_start=510
|
||||
_CONTROLIHAVE._serialized_end=561
|
||||
_CONTROLIWANT._serialized_start=563
|
||||
_CONTROLIWANT._serialized_end=597
|
||||
_CONTROLGRAFT._serialized_start=599
|
||||
_CONTROLGRAFT._serialized_end=630
|
||||
_CONTROLPRUNE._serialized_start=632
|
||||
_CONTROLPRUNE._serialized_end=716
|
||||
_PEERINFO._serialized_start=718
|
||||
_PEERINFO._serialized_end=770
|
||||
_TOPICDESCRIPTOR._serialized_start=773
|
||||
_TOPICDESCRIPTOR._serialized_end=1164
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906
|
||||
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992
|
||||
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033
|
||||
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121
|
||||
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -15,6 +15,10 @@ from libp2p.relay.circuit_v2 import (
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
DCUTR_PROTOCOL_ID,
|
||||
DCUtRProtocol,
|
||||
ReachabilityChecker,
|
||||
is_private_ip,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -25,4 +29,9 @@ __all__ = [
|
||||
"RelayLimits",
|
||||
"RelayResourceManager",
|
||||
"Reservation",
|
||||
"DCUtRProtocol",
|
||||
"DCUTR_PROTOCOL_ID",
|
||||
"ReachabilityChecker",
|
||||
"is_private_ip"
|
||||
|
||||
]
|
||||
|
||||
@ -5,6 +5,16 @@ This package implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
from .dcutr import (
|
||||
DCUtRProtocol,
|
||||
)
|
||||
from .dcutr import PROTOCOL_ID as DCUTR_PROTOCOL_ID
|
||||
|
||||
from .nat import (
|
||||
ReachabilityChecker,
|
||||
is_private_ip,
|
||||
)
|
||||
|
||||
from .discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
@ -29,4 +39,8 @@ __all__ = [
|
||||
"RelayResourceManager",
|
||||
"CircuitV2Transport",
|
||||
"RelayDiscovery",
|
||||
"DCUtRProtocol",
|
||||
"DCUTR_PROTOCOL_ID",
|
||||
"ReachabilityChecker",
|
||||
"is_private_ip",
|
||||
]
|
||||
|
||||
580
libp2p/relay/circuit_v2/dcutr.py
Normal file
580
libp2p/relay/circuit_v2/dcutr.py
Normal file
@ -0,0 +1,580 @@
|
||||
"""
|
||||
Direct Connection Upgrade through Relay (DCUtR) protocol implementation.
|
||||
|
||||
This module implements the DCUtR protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/DCUtR.md
|
||||
|
||||
DCUtR enables peers behind NAT to establish direct connections
|
||||
using hole punching techniques.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetConn,
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.nat import (
|
||||
ReachabilityChecker,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import (
|
||||
HolePunch,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Protocol ID for DCUtR
|
||||
PROTOCOL_ID = TProtocol("/libp2p/dcutr")
|
||||
|
||||
# Maximum message size for DCUtR (4KiB as per spec)
|
||||
MAX_MESSAGE_SIZE = 4 * 1024
|
||||
|
||||
# Timeouts
|
||||
STREAM_READ_TIMEOUT = 30 # seconds
|
||||
STREAM_WRITE_TIMEOUT = 30 # seconds
|
||||
DIAL_TIMEOUT = 10 # seconds
|
||||
|
||||
# Maximum number of hole punch attempts per peer
|
||||
MAX_HOLE_PUNCH_ATTEMPTS = 5
|
||||
|
||||
# Delay between retry attempts
|
||||
HOLE_PUNCH_RETRY_DELAY = 30 # seconds
|
||||
|
||||
# Maximum observed addresses to exchange
|
||||
MAX_OBSERVED_ADDRS = 20
|
||||
|
||||
|
||||
class DCUtRProtocol(Service):
|
||||
"""
|
||||
DCUtRProtocol implements the Direct Connection Upgrade through Relay protocol.
|
||||
|
||||
This protocol allows two NATed peers to establish direct connections through
|
||||
hole punching, after they have established an initial connection through a relay.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost):
|
||||
"""
|
||||
Initialize the DCUtR protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this protocol is running on
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.event_started = trio.Event()
|
||||
self._hole_punch_attempts: dict[ID, int] = {}
|
||||
self._direct_connections: set[ID] = set()
|
||||
self._in_progress: set[ID] = set()
|
||||
self._reachability_checker = ReachabilityChecker(host)
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
|
||||
"""Run the protocol service."""
|
||||
try:
|
||||
# Register the DCUtR protocol handler
|
||||
logger.debug("Registering DCUtR protocol handler")
|
||||
self.host.set_stream_handler(PROTOCOL_ID, self._handle_dcutr_stream)
|
||||
|
||||
# Signal that we're ready
|
||||
self.event_started.set()
|
||||
|
||||
# Start the service
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
task_status.started()
|
||||
logger.debug("DCUtR protocol service started")
|
||||
|
||||
# Wait for service to be stopped
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
# Clean up
|
||||
try:
|
||||
# Use empty async lambda instead of None for stream handler
|
||||
async def empty_handler(_: INetStream) -> None:
|
||||
pass
|
||||
|
||||
self.host.set_stream_handler(PROTOCOL_ID, empty_handler)
|
||||
logger.debug("DCUtR protocol handler unregistered")
|
||||
except Exception as e:
|
||||
logger.error("Error unregistering DCUtR protocol handler: %s", str(e))
|
||||
|
||||
# Clear state
|
||||
self._hole_punch_attempts.clear()
|
||||
self._direct_connections.clear()
|
||||
self._in_progress.clear()
|
||||
self._nursery = None
|
||||
|
||||
async def _handle_dcutr_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming DCUtR streams.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stream : INetStream
|
||||
The incoming stream
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get the remote peer ID
|
||||
remote_peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("Received DCUtR stream from peer %s", remote_peer_id)
|
||||
|
||||
# Check if we already have a direct connection
|
||||
if await self._have_direct_connection(remote_peer_id):
|
||||
logger.debug(
|
||||
"Already have direct connection to %s, closing stream",
|
||||
remote_peer_id,
|
||||
)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Check if there's already an active hole punch attempt
|
||||
if remote_peer_id in self._in_progress:
|
||||
logger.debug("Hole punch already in progress with %s", remote_peer_id)
|
||||
# Let the existing attempt continue
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
# Mark as in progress
|
||||
self._in_progress.add(remote_peer_id)
|
||||
|
||||
try:
|
||||
# Read the CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the message
|
||||
connect_msg = HolePunch()
|
||||
connect_msg.ParseFromString(msg_bytes)
|
||||
|
||||
# Verify it's a CONNECT message
|
||||
if connect_msg.type != HolePunch.CONNECT:
|
||||
logger.warning("Expected CONNECT message, got %s", connect_msg.type)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Received CONNECT message from %s with %d addresses",
|
||||
remote_peer_id,
|
||||
len(connect_msg.ObsAddrs),
|
||||
)
|
||||
|
||||
# Process observed addresses from the peer
|
||||
peer_addrs = self._decode_observed_addrs(list(connect_msg.ObsAddrs))
|
||||
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
|
||||
|
||||
# Store the addresses in the peerstore
|
||||
if peer_addrs:
|
||||
self.host.get_peerstore().add_addrs(
|
||||
remote_peer_id, peer_addrs, 10 * 60
|
||||
) # 10 minute TTL
|
||||
|
||||
# Send our CONNECT message with our observed addresses
|
||||
our_addrs = await self._get_observed_addrs()
|
||||
response = HolePunch()
|
||||
response.type = HolePunch.CONNECT
|
||||
response.ObsAddrs.extend(our_addrs)
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(response.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
"Sent CONNECT response to %s with %d addresses",
|
||||
remote_peer_id,
|
||||
len(our_addrs),
|
||||
)
|
||||
|
||||
# Wait for SYNC message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the SYNC message
|
||||
sync_msg = HolePunch()
|
||||
sync_msg.ParseFromString(sync_bytes)
|
||||
|
||||
# Verify it's a SYNC message
|
||||
if sync_msg.type != HolePunch.SYNC:
|
||||
logger.warning("Expected SYNC message, got %s", sync_msg.type)
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
logger.debug("Received SYNC message from %s", remote_peer_id)
|
||||
|
||||
# Perform hole punch
|
||||
success = await self._perform_hole_punch(remote_peer_id, peer_addrs)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully established direct connection with %s",
|
||||
remote_peer_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to establish direct connection with %s", remote_peer_id
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.warning("Timeout in DCUtR protocol with peer %s", remote_peer_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in DCUtR protocol with peer %s: %s", remote_peer_id, str(e)
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
self._in_progress.discard(remote_peer_id)
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling DCUtR stream: %s", str(e))
|
||||
await stream.close()
|
||||
|
||||
async def initiate_hole_punch(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Initiate a hole punch with a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to hole punch with
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if hole punch was successful, False otherwise
|
||||
|
||||
"""
|
||||
# Check if we already have a direct connection
|
||||
if await self._have_direct_connection(peer_id):
|
||||
logger.debug("Already have direct connection to %s", peer_id)
|
||||
return True
|
||||
|
||||
# Check if there's already an active hole punch attempt
|
||||
if peer_id in self._in_progress:
|
||||
logger.debug("Hole punch already in progress with %s", peer_id)
|
||||
return False
|
||||
|
||||
# Check if we've exceeded the maximum number of attempts
|
||||
attempts = self._hole_punch_attempts.get(peer_id, 0)
|
||||
if attempts >= MAX_HOLE_PUNCH_ATTEMPTS:
|
||||
logger.warning("Maximum hole punch attempts reached for peer %s", peer_id)
|
||||
return False
|
||||
|
||||
# Mark as in progress and increment attempt counter
|
||||
self._in_progress.add(peer_id)
|
||||
self._hole_punch_attempts[peer_id] = attempts + 1
|
||||
|
||||
try:
|
||||
# Open a DCUtR stream to the peer
|
||||
logger.debug("Opening DCUtR stream to peer %s", peer_id)
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if not stream:
|
||||
logger.warning("Failed to open DCUtR stream to peer %s", peer_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Send our CONNECT message with our observed addresses
|
||||
our_addrs = await self._get_observed_addrs()
|
||||
connect_msg = HolePunch()
|
||||
connect_msg.type = HolePunch.CONNECT
|
||||
connect_msg.ObsAddrs.extend(our_addrs)
|
||||
|
||||
start_time = time.time()
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(connect_msg.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
"Sent CONNECT message to %s with %d addresses",
|
||||
peer_id,
|
||||
len(our_addrs),
|
||||
)
|
||||
|
||||
# Receive the peer's CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Calculate RTT
|
||||
rtt = time.time() - start_time
|
||||
|
||||
# Parse the response
|
||||
resp = HolePunch()
|
||||
resp.ParseFromString(resp_bytes)
|
||||
|
||||
# Verify it's a CONNECT message
|
||||
if resp.type != HolePunch.CONNECT:
|
||||
logger.warning("Expected CONNECT message, got %s", resp.type)
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
"Received CONNECT response from %s with %d addresses",
|
||||
peer_id,
|
||||
len(resp.ObsAddrs),
|
||||
)
|
||||
|
||||
# Process observed addresses from the peer
|
||||
peer_addrs = self._decode_observed_addrs(list(resp.ObsAddrs))
|
||||
logger.debug("Decoded %d valid addresses from peer", len(peer_addrs))
|
||||
|
||||
# Store the addresses in the peerstore
|
||||
if peer_addrs:
|
||||
self.host.get_peerstore().add_addrs(
|
||||
peer_id, peer_addrs, 10 * 60
|
||||
) # 10 minute TTL
|
||||
|
||||
# Send SYNC message with timing information
|
||||
# We'll use a future time that's 2*RTT from now to ensure both sides
|
||||
# are ready
|
||||
punch_time = time.time() + (2 * rtt) + 1 # Add 1 second buffer
|
||||
|
||||
sync_msg = HolePunch()
|
||||
sync_msg.type = HolePunch.SYNC
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await stream.write(sync_msg.SerializeToString())
|
||||
|
||||
logger.debug("Sent SYNC message to %s", peer_id)
|
||||
|
||||
# Perform the synchronized hole punch
|
||||
success = await self._perform_hole_punch(
|
||||
peer_id, peer_addrs, punch_time
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully established direct connection with %s", peer_id
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to establish direct connection with %s", peer_id
|
||||
)
|
||||
return False
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.warning("Timeout in DCUtR protocol with peer %s", peer_id)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in DCUtR protocol with peer %s: %s", peer_id, str(e)
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error initiating hole punch with peer %s: %s", peer_id, str(e)
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
self._in_progress.discard(peer_id)
|
||||
|
||||
# This should never be reached, but add explicit return for type checking
|
||||
return False
|
||||
|
||||
async def _perform_hole_punch(
|
||||
self, peer_id: ID, addrs: list[Multiaddr], punch_time: float | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Perform a hole punch attempt with a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to hole punch with
|
||||
addrs : list[Multiaddr]
|
||||
List of addresses to try
|
||||
punch_time : Optional[float]
|
||||
Time to perform the punch (if None, do it immediately)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if hole punch was successful
|
||||
|
||||
"""
|
||||
if not addrs:
|
||||
logger.warning("No addresses to try for hole punch with %s", peer_id)
|
||||
return False
|
||||
|
||||
# If punch_time is specified, wait until that time
|
||||
if punch_time is not None:
|
||||
now = time.time()
|
||||
if punch_time > now:
|
||||
wait_time = punch_time - now
|
||||
logger.debug("Waiting %.2f seconds before hole punch", wait_time)
|
||||
await trio.sleep(wait_time)
|
||||
|
||||
# Try to dial each address
|
||||
logger.debug(
|
||||
"Starting hole punch with peer %s using %d addresses", peer_id, len(addrs)
|
||||
)
|
||||
|
||||
# Filter to only include non-relay addresses
|
||||
direct_addrs = [
|
||||
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
|
||||
]
|
||||
|
||||
if not direct_addrs:
|
||||
logger.warning("No direct addresses found for peer %s", peer_id)
|
||||
return False
|
||||
|
||||
# Start dialing attempts in parallel
|
||||
async with trio.open_nursery() as nursery:
|
||||
for addr in direct_addrs[
|
||||
:5
|
||||
]: # Limit to 5 addresses to avoid too many connections
|
||||
nursery.start_soon(self._dial_peer, peer_id, addr)
|
||||
|
||||
# Check if we established a direct connection
|
||||
return await self._have_direct_connection(peer_id)
|
||||
|
||||
async def _dial_peer(self, peer_id: ID, addr: Multiaddr) -> None:
|
||||
"""
|
||||
Attempt to dial a peer at a specific address.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to dial
|
||||
addr : Multiaddr
|
||||
The address to dial
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.debug("Attempting to dial %s at %s", peer_id, addr)
|
||||
|
||||
# Create peer info
|
||||
peer_info = PeerInfo(peer_id, [addr])
|
||||
|
||||
# Try to connect with timeout
|
||||
with trio.fail_after(DIAL_TIMEOUT):
|
||||
await self.host.connect(peer_info)
|
||||
|
||||
logger.info("Successfully connected to %s at %s", peer_id, addr)
|
||||
|
||||
# Add to direct connections set
|
||||
self._direct_connections.add(peer_id)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug("Timeout dialing %s at %s", peer_id, addr)
|
||||
except Exception as e:
|
||||
logger.debug("Error dialing %s at %s: %s", peer_id, addr, str(e))
|
||||
|
||||
async def _have_direct_connection(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if we already have a direct connection to a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if we have a direct connection, False otherwise
|
||||
|
||||
"""
|
||||
# Check our direct connections cache first
|
||||
if peer_id in self._direct_connections:
|
||||
return True
|
||||
|
||||
# Check if the peer is connected
|
||||
network = self.host.get_network()
|
||||
conn_or_conns = network.connections.get(peer_id)
|
||||
if not conn_or_conns:
|
||||
return False
|
||||
|
||||
# Handle both single connection and list of connections
|
||||
connections: list[INetConn] = (
|
||||
[conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns
|
||||
)
|
||||
|
||||
# Check if any connection is direct (not relayed)
|
||||
for conn in connections:
|
||||
# Get the transport addresses
|
||||
addrs = conn.get_transport_addresses()
|
||||
|
||||
# If any address doesn't start with /p2p-circuit, it's a direct connection
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
# Cache this result
|
||||
self._direct_connections.add(peer_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _get_observed_addrs(self) -> list[bytes]:
|
||||
"""
|
||||
Get our observed addresses to share with the peer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[bytes]
|
||||
List of observed addresses as bytes
|
||||
|
||||
"""
|
||||
# Get all listen addresses
|
||||
addrs = self.host.get_addrs()
|
||||
|
||||
# Filter out relay addresses
|
||||
direct_addrs = [
|
||||
addr for addr in addrs if not str(addr).startswith("/p2p-circuit")
|
||||
]
|
||||
|
||||
# Limit the number of addresses
|
||||
if len(direct_addrs) > MAX_OBSERVED_ADDRS:
|
||||
direct_addrs = direct_addrs[:MAX_OBSERVED_ADDRS]
|
||||
|
||||
# Convert to bytes
|
||||
addr_bytes = [addr.to_bytes() for addr in direct_addrs]
|
||||
|
||||
return addr_bytes
|
||||
|
||||
def _decode_observed_addrs(self, addr_bytes: list[bytes]) -> list[Multiaddr]:
|
||||
"""
|
||||
Decode observed addresses received from a peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr_bytes : List[bytes]
|
||||
The encoded addresses
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Multiaddr]
|
||||
The decoded multiaddresses
|
||||
|
||||
"""
|
||||
result = []
|
||||
|
||||
for addr_byte in addr_bytes:
|
||||
try:
|
||||
addr = Multiaddr(addr_byte)
|
||||
# Validate the address (basic check)
|
||||
if str(addr).startswith("/ip"):
|
||||
result.append(addr)
|
||||
except Exception as e:
|
||||
logger.debug("Error decoding multiaddr: %s", str(e))
|
||||
|
||||
return result
|
||||
300
libp2p/relay/circuit_v2/nat.py
Normal file
300
libp2p/relay/circuit_v2/nat.py
Normal file
@ -0,0 +1,300 @@
|
||||
"""
|
||||
NAT traversal utilities for libp2p.
|
||||
|
||||
This module provides utilities for NAT traversal and reachability detection.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetConn,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.nat")
|
||||
|
||||
# Timeout for reachability checks
|
||||
REACHABILITY_TIMEOUT = 10 # seconds
|
||||
|
||||
# Define private IP ranges
|
||||
PRIVATE_IP_RANGES = [
|
||||
("10.0.0.0", "10.255.255.255"), # Class A private network: 10.0.0.0/8
|
||||
("172.16.0.0", "172.31.255.255"), # Class B private network: 172.16.0.0/12
|
||||
("192.168.0.0", "192.168.255.255"), # Class C private network: 192.168.0.0/16
|
||||
]
|
||||
|
||||
# Link-local address range: 169.254.0.0/16
|
||||
LINK_LOCAL_RANGE = ("169.254.0.0", "169.254.255.255")
|
||||
|
||||
# Loopback address range: 127.0.0.0/8
|
||||
LOOPBACK_RANGE = ("127.0.0.0", "127.255.255.255")
|
||||
|
||||
|
||||
def ip_to_int(ip: str) -> int:
|
||||
"""
|
||||
Convert an IP address to an integer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to convert
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Integer representation of the IP
|
||||
|
||||
"""
|
||||
try:
|
||||
return int(ipaddress.IPv4Address(ip))
|
||||
except ipaddress.AddressValueError:
|
||||
# Handle IPv6 addresses
|
||||
return int(ipaddress.IPv6Address(ip))
|
||||
|
||||
|
||||
def is_ip_in_range(ip: str, start_range: str, end_range: str) -> bool:
|
||||
"""
|
||||
Check if an IP address is within a range.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to check
|
||||
start_range : str
|
||||
Start of the range
|
||||
end_range : str
|
||||
End of the range
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the IP is in the range
|
||||
|
||||
"""
|
||||
try:
|
||||
ip_int = ip_to_int(ip)
|
||||
start_int = ip_to_int(start_range)
|
||||
end_int = ip_to_int(end_range)
|
||||
return start_int <= ip_int <= end_int
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_private_ip(ip: str) -> bool:
|
||||
"""
|
||||
Check if an IP address is private.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ip : str
|
||||
IP address to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if IP is private
|
||||
|
||||
"""
|
||||
for start_range, end_range in PRIVATE_IP_RANGES:
|
||||
if is_ip_in_range(ip, start_range, end_range):
|
||||
return True
|
||||
|
||||
# Check for link-local addresses
|
||||
if is_ip_in_range(ip, *LINK_LOCAL_RANGE):
|
||||
return True
|
||||
|
||||
# Check for loopback addresses
|
||||
if is_ip_in_range(ip, *LOOPBACK_RANGE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def extract_ip_from_multiaddr(addr: Multiaddr) -> str | None:
|
||||
"""
|
||||
Extract the IP address from a multiaddr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr : Multiaddr
|
||||
Multiaddr to extract from
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[str]
|
||||
IP address or None if not found
|
||||
|
||||
"""
|
||||
# Convert to string representation
|
||||
addr_str = str(addr)
|
||||
|
||||
# Look for IPv4 address
|
||||
ipv4_start = addr_str.find("/ip4/")
|
||||
if ipv4_start != -1:
|
||||
# Extract the IPv4 address
|
||||
ipv4_end = addr_str.find("/", ipv4_start + 5)
|
||||
if ipv4_end != -1:
|
||||
return addr_str[ipv4_start + 5 : ipv4_end]
|
||||
|
||||
# Look for IPv6 address
|
||||
ipv6_start = addr_str.find("/ip6/")
|
||||
if ipv6_start != -1:
|
||||
# Extract the IPv6 address
|
||||
ipv6_end = addr_str.find("/", ipv6_start + 5)
|
||||
if ipv6_end != -1:
|
||||
return addr_str[ipv6_start + 5 : ipv6_end]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ReachabilityChecker:
|
||||
"""
|
||||
Utility class for checking peer reachability.
|
||||
|
||||
This class assesses whether a peer's addresses are likely
|
||||
to be directly reachable or behind NAT.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost):
|
||||
"""
|
||||
Initialize the reachability checker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self._peer_reachability: dict[ID, bool] = {}
|
||||
self._known_public_peers: set[ID] = set()
|
||||
|
||||
def is_addr_public(self, addr: Multiaddr) -> bool:
|
||||
"""
|
||||
Check if an address is likely to be publicly reachable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addr : Multiaddr
|
||||
The multiaddr to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if address is likely public
|
||||
|
||||
"""
|
||||
# Extract the IP address
|
||||
ip = extract_ip_from_multiaddr(addr)
|
||||
if not ip:
|
||||
return False
|
||||
|
||||
# Check if it's a private IP
|
||||
return not is_private_ip(ip)
|
||||
|
||||
def get_public_addrs(self, addrs: list[Multiaddr]) -> list[Multiaddr]:
|
||||
"""
|
||||
Filter a list of addresses to only include likely public ones.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
addrs : List[Multiaddr]
|
||||
List of addresses to filter
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Multiaddr]
|
||||
List of likely public addresses
|
||||
|
||||
"""
|
||||
return [addr for addr in addrs if self.is_addr_public(addr)]
|
||||
|
||||
async def check_peer_reachability(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer is directly reachable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if peer is likely directly reachable
|
||||
|
||||
"""
|
||||
# Check if we already know
|
||||
if peer_id in self._peer_reachability:
|
||||
return self._peer_reachability[peer_id]
|
||||
|
||||
# Check if the peer is connected
|
||||
network = self.host.get_network()
|
||||
connections: INetConn | list[INetConn] | None = network.connections.get(peer_id)
|
||||
if not connections:
|
||||
# Not connected, can't determine reachability
|
||||
return False
|
||||
|
||||
# Check if any connection is direct (not relayed)
|
||||
if isinstance(connections, list):
|
||||
for conn in connections:
|
||||
# Get the transport addresses
|
||||
addrs = conn.get_transport_addresses()
|
||||
|
||||
# If any address doesn't start with /p2p-circuit,
|
||||
# it's a direct connection
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
else:
|
||||
# Handle single connection case
|
||||
addrs = connections.get_transport_addresses()
|
||||
if any(not str(addr).startswith("/p2p-circuit") for addr in addrs):
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
|
||||
# Get the peer's addresses from peerstore
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
# Check if peer has any public addresses
|
||||
public_addrs = self.get_public_addrs(addrs)
|
||||
if public_addrs:
|
||||
self._peer_reachability[peer_id] = True
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Error getting peer addresses: %s", str(e))
|
||||
|
||||
# Default to not directly reachable
|
||||
self._peer_reachability[peer_id] = False
|
||||
return False
|
||||
|
||||
async def check_self_reachability(self) -> tuple[bool, list[Multiaddr]]:
|
||||
"""
|
||||
Check if this host is likely directly reachable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[bool, List[Multiaddr]]
|
||||
Tuple of (is_reachable, public_addresses)
|
||||
|
||||
"""
|
||||
# Get all host addresses
|
||||
addrs = self.host.get_addrs()
|
||||
|
||||
# Filter for public addresses
|
||||
public_addrs = self.get_public_addrs(addrs)
|
||||
|
||||
# If we have public addresses, assume we're reachable
|
||||
# This is a simplified assumption - real reachability would need
|
||||
# external checking
|
||||
is_reachable = len(public_addrs) > 0
|
||||
|
||||
return is_reachable, public_addrs
|
||||
@ -5,6 +5,11 @@ Contains generated protobuf code for circuit_v2 relay protocol.
|
||||
"""
|
||||
|
||||
# Import the classes to be accessible directly from the package
|
||||
|
||||
from .dcutr_pb2 import (
|
||||
HolePunch,
|
||||
)
|
||||
|
||||
from .circuit_pb2 import (
|
||||
HopMessage,
|
||||
Limit,
|
||||
@ -13,4 +18,4 @@ from .circuit_pb2 import (
|
||||
StopMessage,
|
||||
)
|
||||
|
||||
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]
|
||||
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage", "HolePunch"]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: libp2p/relay/circuit_v2/pb/circuit.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
@ -12,11 +11,14 @@ from google.protobuf import symbol_database as _symbol_database
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_HOPMESSAGE._serialized_start=60
|
||||
_HOPMESSAGE._serialized_end=303
|
||||
|
||||
14
libp2p/relay/circuit_v2/pb/dcutr.proto
Normal file
14
libp2p/relay/circuit_v2/pb/dcutr.proto
Normal file
@ -0,0 +1,14 @@
|
||||
syntax = "proto2";
|
||||
|
||||
package holepunch.pb;
|
||||
|
||||
message HolePunch {
|
||||
enum Type {
|
||||
CONNECT = 100;
|
||||
SYNC = 300;
|
||||
}
|
||||
|
||||
required Type type = 1;
|
||||
|
||||
repeated bytes ObsAddrs = 2;
|
||||
}
|
||||
27
libp2p/relay/circuit_v2/pb/dcutr_pb2.py
Normal file
27
libp2p/relay/circuit_v2/pb/dcutr_pb2.py
Normal file
@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/relay/circuit_v2/pb/dcutr.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&libp2p/relay/circuit_v2/pb/dcutr.proto\x12\x0cholepunch.pb\"i\n\tHolePunch\x12*\n\x04type\x18\x01 \x02(\x0e\x32\x1c.holepunch.pb.HolePunch.Type\x12\x10\n\x08ObsAddrs\x18\x02 \x03(\x0c\"\x1e\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x64\x12\t\n\x04SYNC\x10\xac\x02')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.dcutr_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_HOLEPUNCH._serialized_start=56
|
||||
_HOLEPUNCH._serialized_end=161
|
||||
_HOLEPUNCH_TYPE._serialized_start=131
|
||||
_HOLEPUNCH_TYPE._serialized_end=161
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
53
libp2p/relay/circuit_v2/pb/dcutr_pb2.pyi
Normal file
53
libp2p/relay/circuit_v2/pb/dcutr_pb2.pyi
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class HolePunch(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _Type:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HolePunch._Type.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
CONNECT: HolePunch._Type.ValueType # 100
|
||||
SYNC: HolePunch._Type.ValueType # 300
|
||||
|
||||
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
|
||||
CONNECT: HolePunch.Type.ValueType # 100
|
||||
SYNC: HolePunch.Type.ValueType # 300
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
OBSADDRS_FIELD_NUMBER: builtins.int
|
||||
type: global___HolePunch.Type.ValueType
|
||||
@property
|
||||
def ObsAddrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___HolePunch.Type.ValueType | None = ...,
|
||||
ObsAddrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["type", b"type"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["ObsAddrs", b"ObsAddrs", "type", b"type"]) -> None: ...
|
||||
|
||||
global___HolePunch = HolePunch
|
||||
@ -41,7 +41,8 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
read_writer: NoisePacketReadWriter
|
||||
noise_state: NoiseState
|
||||
|
||||
# FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior.
|
||||
# NOTE: This prefix is added in msg#3 in Go.
|
||||
# Support in py-libp2p is available but not used
|
||||
prefix: bytes = b"\x00" * 32
|
||||
|
||||
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
|
||||
|
||||
@ -29,11 +29,6 @@ class Transport(ISecureTransport):
|
||||
early_data: bytes | None
|
||||
with_noise_pipes: bool
|
||||
|
||||
# NOTE: Implementations that support Noise Pipes must decide whether to use
|
||||
# an XX or IK handshake based on whether they possess a cached static
|
||||
# Noise key for the remote peer.
|
||||
# TODO: A storage of seen noise static keys for pattern IK?
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
libp2p_keypair: KeyPair,
|
||||
|
||||
1
newsfragments/592.internal.rst
Normal file
1
newsfragments/592.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
remove FIXME comment since it's obsolete and 32-byte prefix support is there but not enabled by default
|
||||
1
newsfragments/816.internal.rst
Normal file
1
newsfragments/816.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
The TODO IK patterns in Noise has been deprecated in specs: https://github.com/libp2p/specs/tree/master/noise#handshake-pattern
|
||||
1
newsfragments/818.bugfix.rst
Normal file
1
newsfragments/818.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Recompiled protobufs that were out of date and added a `make` rule so that protobufs are always up to date.
|
||||
3
newsfragments/819.internal.rst
Normal file
3
newsfragments/819.internal.rst
Normal file
@ -0,0 +1,3 @@
|
||||
Remove the already completed TODO tasks in Peerstore:
|
||||
TODO: Set up an async task for periodic peer-store cleanup for expired addresses and records.
|
||||
TODO: Make proper use of this function
|
||||
563
tests/core/relay/test_dcutr_integration.py
Normal file
563
tests/core/relay/test_dcutr_integration.py
Normal file
@ -0,0 +1,563 @@
|
||||
"""Integration tests for DCUtR protocol with real libp2p hosts using circuit relay."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.relay.circuit_v2.dcutr import (
|
||||
MAX_HOLE_PUNCH_ATTEMPTS,
|
||||
PROTOCOL_ID,
|
||||
DCUtRProtocol,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import (
|
||||
HolePunch,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.protocol import (
|
||||
DEFAULT_RELAY_LIMITS,
|
||||
CircuitV2Protocol,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
HostFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test timeouts
|
||||
SLEEP_TIME = 0.5 # seconds
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_through_relay_connection():
|
||||
"""
|
||||
Test DCUtR protocol where peers are connected via relay,
|
||||
then upgrade to direct.
|
||||
"""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Track if DCUtR stream handlers were called
|
||||
handler1_called = False
|
||||
handler2_called = False
|
||||
|
||||
# Override stream handlers to track calls
|
||||
original_handler1 = dcutr1._handle_dcutr_stream
|
||||
original_handler2 = dcutr2._handle_dcutr_stream
|
||||
|
||||
async def tracked_handler1(stream):
|
||||
nonlocal handler1_called
|
||||
handler1_called = True
|
||||
await original_handler1(stream)
|
||||
|
||||
async def tracked_handler2(stream):
|
||||
nonlocal handler2_called
|
||||
handler2_called = True
|
||||
await original_handler2(stream)
|
||||
|
||||
dcutr1._handle_dcutr_stream = tracked_handler1
|
||||
dcutr2._handle_dcutr_stream = tracked_handler2
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify peers are connected to relay
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Verify peers are NOT directly connected to each other
|
||||
assert peer2.get_id() not in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert peer1.get_id() not in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Now test DCUtR: peer1 opens a DCUtR stream to peer2 through the
|
||||
# relay
|
||||
# This should trigger the DCUtR protocol for hole punching
|
||||
try:
|
||||
# Create a circuit relay multiaddr for peer2 through the relay
|
||||
relay_addr = relay_addrs[0]
|
||||
circuit_addr = Multiaddr(
|
||||
f"{relay_addr}/p2p-circuit/p2p/{peer2.get_id()}"
|
||||
)
|
||||
|
||||
# Add the circuit address to peer1's peerstore
|
||||
peer1.get_peerstore().add_addrs(
|
||||
peer2.get_id(), [circuit_addr], 3600
|
||||
)
|
||||
|
||||
# Open a DCUtR stream from peer1 to peer2 through the relay
|
||||
stream = await peer1.new_stream(peer2.get_id(), [PROTOCOL_ID])
|
||||
|
||||
# Send a CONNECT message with observed addresses
|
||||
peer1_addrs = peer1.get_addrs()
|
||||
connect_msg = HolePunch(
|
||||
type=HolePunch.CONNECT,
|
||||
ObsAddrs=[addr.to_bytes() for addr in peer1_addrs[:2]],
|
||||
)
|
||||
await stream.write(connect_msg.SerializeToString())
|
||||
|
||||
# Wait for the message to be processed
|
||||
await trio.sleep(0.2)
|
||||
|
||||
# Verify that the DCUtR stream handler was called on peer2
|
||||
assert handler2_called, (
|
||||
"DCUtR stream handler should have been called on peer2"
|
||||
)
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Expected error when trying to open DCUtR stream through "
|
||||
"relay: %s",
|
||||
e,
|
||||
)
|
||||
# This might fail because we need more setup, but the important
|
||||
# thing is testing the right scenario
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_relay_to_direct_upgrade():
|
||||
"""Test the complete flow: relay connection -> DCUtR -> direct connection."""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Track messages received
|
||||
messages_received = []
|
||||
|
||||
# Override stream handler to capture messages
|
||||
original_handler = dcutr2._handle_dcutr_stream
|
||||
|
||||
async def message_capturing_handler(stream):
|
||||
try:
|
||||
# Read the message
|
||||
msg_data = await stream.read()
|
||||
hole_punch = HolePunch()
|
||||
hole_punch.ParseFromString(msg_data)
|
||||
messages_received.append(hole_punch)
|
||||
|
||||
# Send a SYNC response
|
||||
sync_msg = HolePunch(type=HolePunch.SYNC)
|
||||
await stream.write(sync_msg.SerializeToString())
|
||||
|
||||
await original_handler(stream)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message capturing handler: {e}")
|
||||
await stream.close()
|
||||
|
||||
dcutr2._handle_dcutr_stream = message_capturing_handler
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Re-register the handler with the host
|
||||
dcutr2.host.set_stream_handler(
|
||||
PROTOCOL_ID, message_capturing_handler
|
||||
)
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify peers are connected to relay but not to each other
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
assert peer2.get_id() not in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Try to open a DCUtR stream through the relay
|
||||
try:
|
||||
# Create a circuit relay multiaddr for peer2 through the relay
|
||||
relay_addr = relay_addrs[0]
|
||||
circuit_addr = Multiaddr(
|
||||
f"{relay_addr}/p2p-circuit/p2p/{peer2.get_id()}"
|
||||
)
|
||||
|
||||
# Add the circuit address to peer1's peerstore
|
||||
peer1.get_peerstore().add_addrs(
|
||||
peer2.get_id(), [circuit_addr], 3600
|
||||
)
|
||||
|
||||
# Open a DCUtR stream from peer1 to peer2 through the relay
|
||||
stream = await peer1.new_stream(peer2.get_id(), [PROTOCOL_ID])
|
||||
|
||||
# Send a CONNECT message with observed addresses
|
||||
peer1_addrs = peer1.get_addrs()
|
||||
connect_msg = HolePunch(
|
||||
type=HolePunch.CONNECT,
|
||||
ObsAddrs=[addr.to_bytes() for addr in peer1_addrs[:2]],
|
||||
)
|
||||
await stream.write(connect_msg.SerializeToString())
|
||||
|
||||
# Wait for the message to be processed
|
||||
await trio.sleep(0.2)
|
||||
|
||||
# Verify that the CONNECT message was received
|
||||
assert len(messages_received) == 1, (
|
||||
"Should have received one message"
|
||||
)
|
||||
assert messages_received[0].type == HolePunch.CONNECT, (
|
||||
"Should have received CONNECT message"
|
||||
)
|
||||
assert len(messages_received[0].ObsAddrs) == 2, (
|
||||
"Should have received 2 observed addresses"
|
||||
)
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Expected error when trying to open DCUtR stream through "
|
||||
"relay: %s",
|
||||
e,
|
||||
)
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_hole_punch_through_relay():
|
||||
"""Test hole punching when peers are connected through relay."""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify peers are connected to relay but not to each other
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
assert peer2.get_id() not in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Check if there's already a direct connection (should be False)
|
||||
has_direct = await dcutr1._have_direct_connection(peer2.get_id())
|
||||
assert not has_direct, "Peers should not have a direct connection"
|
||||
|
||||
# Try to initiate a hole punch (this should work through the relay
|
||||
# connection)
|
||||
# In a real scenario, this would be called after establishing a
|
||||
# relay connection
|
||||
result = await dcutr1.initiate_hole_punch(peer2.get_id())
|
||||
|
||||
# This should attempt hole punching but likely fail due to no public
|
||||
# addresses
|
||||
# The important thing is that the DCUtR protocol logic is executed
|
||||
logger.info(
|
||||
"Hole punch result: %s",
|
||||
result,
|
||||
)
|
||||
|
||||
assert result is not None, "Hole punch result should not be None"
|
||||
assert isinstance(result, bool), (
|
||||
"Hole punch result should be a boolean"
|
||||
)
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_relay_connection_verification():
|
||||
"""Test that DCUtR works correctly when peers are connected via relay."""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify peers are connected to relay
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert relay.get_id() in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Verify peers are NOT directly connected to each other
|
||||
assert peer2.get_id() not in [
|
||||
peer_id for peer_id in peer1.get_network().connections.keys()
|
||||
]
|
||||
assert peer1.get_id() not in [
|
||||
peer_id for peer_id in peer2.get_network().connections.keys()
|
||||
]
|
||||
|
||||
# Test getting observed addresses (real implementation)
|
||||
observed_addrs1 = await dcutr1._get_observed_addrs()
|
||||
observed_addrs2 = await dcutr2._get_observed_addrs()
|
||||
|
||||
assert isinstance(observed_addrs1, list)
|
||||
assert isinstance(observed_addrs2, list)
|
||||
|
||||
# Should contain the hosts' actual addresses
|
||||
assert len(observed_addrs1) > 0, (
|
||||
"Peer1 should have observed addresses"
|
||||
)
|
||||
assert len(observed_addrs2) > 0, (
|
||||
"Peer2 should have observed addresses"
|
||||
)
|
||||
|
||||
# Test decoding observed addresses
|
||||
test_addrs = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
|
||||
b"invalid-addr", # This should be filtered out
|
||||
]
|
||||
decoded = dcutr1._decode_observed_addrs(test_addrs)
|
||||
assert len(decoded) == 2, "Should decode 2 valid addresses"
|
||||
assert all(str(addr).startswith("/ip4/") for addr in decoded)
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_relay_error_handling():
|
||||
"""Test DCUtR error handling when working through relay connections."""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Test with a stream that times out
|
||||
timeout_stream = MagicMock()
|
||||
timeout_stream.muxed_conn.peer_id = peer2.get_id()
|
||||
timeout_stream.read = AsyncMock(side_effect=trio.TooSlowError())
|
||||
timeout_stream.write = AsyncMock()
|
||||
timeout_stream.close = AsyncMock()
|
||||
|
||||
# This should not raise an exception, just log and close
|
||||
await dcutr1._handle_dcutr_stream(timeout_stream)
|
||||
|
||||
# Verify stream was closed
|
||||
assert timeout_stream.close.called
|
||||
|
||||
# Test with malformed message
|
||||
malformed_stream = MagicMock()
|
||||
malformed_stream.muxed_conn.peer_id = peer2.get_id()
|
||||
malformed_stream.read = AsyncMock(return_value=b"not-a-protobuf")
|
||||
malformed_stream.write = AsyncMock()
|
||||
malformed_stream.close = AsyncMock()
|
||||
|
||||
# This should not raise an exception, just log and close
|
||||
await dcutr1._handle_dcutr_stream(malformed_stream)
|
||||
|
||||
# Verify stream was closed
|
||||
assert malformed_stream.close.called
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_relay_attempt_limiting():
|
||||
"""Test DCUtR attempt limiting when working through relay connections."""
|
||||
# Create three hosts: two peers and one relay
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
peer1, peer2, relay = hosts
|
||||
|
||||
# Create circuit relay protocol for the relay
|
||||
relay_protocol = CircuitV2Protocol(relay, DEFAULT_RELAY_LIMITS, allow_hop=True)
|
||||
|
||||
# Create DCUtR protocols for both peers
|
||||
dcutr1 = DCUtRProtocol(peer1)
|
||||
dcutr2 = DCUtRProtocol(peer2)
|
||||
|
||||
# Start all protocols
|
||||
async with background_trio_service(relay_protocol):
|
||||
async with background_trio_service(dcutr1):
|
||||
async with background_trio_service(dcutr2):
|
||||
await relay_protocol.event_started.wait()
|
||||
await dcutr1.event_started.wait()
|
||||
await dcutr2.event_started.wait()
|
||||
|
||||
# Connect both peers to the relay
|
||||
relay_addrs = relay.get_addrs()
|
||||
|
||||
# Add relay addresses to both peers' peerstores
|
||||
for addr in relay_addrs:
|
||||
peer1.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
peer2.get_peerstore().add_addrs(relay.get_id(), [addr], 3600)
|
||||
|
||||
# Connect peers to relay
|
||||
await peer1.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await peer2.connect(relay.get_peerstore().peer_info(relay.get_id()))
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Set max attempts reached
|
||||
dcutr1._hole_punch_attempts[peer2.get_id()] = (
|
||||
MAX_HOLE_PUNCH_ATTEMPTS
|
||||
)
|
||||
|
||||
# Try to initiate hole punch - should fail due to max attempts
|
||||
result = await dcutr1.initiate_hole_punch(peer2.get_id())
|
||||
assert result is False, "Hole punch should fail due to max attempts"
|
||||
|
||||
# Reset attempts
|
||||
dcutr1._hole_punch_attempts.clear()
|
||||
|
||||
# Add to direct connections
|
||||
dcutr1._direct_connections.add(peer2.get_id())
|
||||
|
||||
# Try to initiate hole punch - should succeed immediately
|
||||
result = await dcutr1.initiate_hole_punch(peer2.get_id())
|
||||
assert result is True, (
|
||||
"Hole punch should succeed for already connected peers"
|
||||
)
|
||||
|
||||
# Wait a bit more
|
||||
await trio.sleep(0.1)
|
||||
208
tests/core/relay/test_dcutr_protocol.py
Normal file
208
tests/core/relay/test_dcutr_protocol.py
Normal file
@ -0,0 +1,208 @@
|
||||
"""Unit tests for DCUtR protocol."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.relay.circuit_v2.dcutr import (
|
||||
MAX_HOLE_PUNCH_ATTEMPTS,
|
||||
DCUtRProtocol,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb.dcutr_pb2 import HolePunch
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_protocol_initialization():
|
||||
"""Test DCUtR protocol initialization."""
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
|
||||
# Test that the protocol is initialized correctly
|
||||
assert dcutr.host == mock_host
|
||||
assert not dcutr.event_started.is_set()
|
||||
assert dcutr._hole_punch_attempts == {}
|
||||
assert dcutr._direct_connections == set()
|
||||
assert dcutr._in_progress == set()
|
||||
|
||||
# Test that the protocol can be started
|
||||
async with background_trio_service(dcutr):
|
||||
# Wait for the protocol to start
|
||||
await dcutr.event_started.wait()
|
||||
|
||||
# Verify that the stream handler was registered
|
||||
mock_host.set_stream_handler.assert_called_once()
|
||||
|
||||
# Verify that the event is set
|
||||
assert dcutr.event_started.is_set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_message_exchange():
|
||||
"""Test DCUtR message exchange."""
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
|
||||
# Test that the protocol can be started
|
||||
async with background_trio_service(dcutr):
|
||||
# Wait for the protocol to start
|
||||
await dcutr.event_started.wait()
|
||||
|
||||
# Test CONNECT message
|
||||
connect_msg = HolePunch(
|
||||
type=HolePunch.CONNECT,
|
||||
ObsAddrs=[b"/ip4/127.0.0.1/tcp/1234", b"/ip4/192.168.1.1/tcp/5678"],
|
||||
)
|
||||
|
||||
# Test SYNC message
|
||||
sync_msg = HolePunch(type=HolePunch.SYNC)
|
||||
|
||||
# Verify message types
|
||||
assert connect_msg.type == HolePunch.CONNECT
|
||||
assert sync_msg.type == HolePunch.SYNC
|
||||
assert len(connect_msg.ObsAddrs) == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_error_handling(monkeypatch):
|
||||
"""Test DCUtR error handling."""
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
|
||||
async with background_trio_service(dcutr):
|
||||
await dcutr.event_started.wait()
|
||||
|
||||
# Simulate a stream that times out
|
||||
class TimeoutStream(INetStream):
|
||||
def __init__(self):
|
||||
self._protocol = None
|
||||
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
await trio.sleep(0.2)
|
||||
raise trio.TooSlowError()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
return None
|
||||
|
||||
async def close(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def reset(self):
|
||||
return None
|
||||
|
||||
def get_protocol(self):
|
||||
return self._protocol
|
||||
|
||||
def set_protocol(self, protocol_id):
|
||||
self._protocol = protocol_id
|
||||
|
||||
def get_remote_address(self):
|
||||
return ("127.0.0.1", 1234)
|
||||
|
||||
# Should not raise, just log and close
|
||||
await dcutr._handle_dcutr_stream(TimeoutStream())
|
||||
|
||||
# Simulate a stream with malformed message
|
||||
class MalformedStream(INetStream):
|
||||
def __init__(self):
|
||||
self._protocol = None
|
||||
self.muxed_conn = MagicMock(peer_id=ID(b"peer"))
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
return b"not-a-protobuf"
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
return None
|
||||
|
||||
async def close(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def reset(self):
|
||||
return None
|
||||
|
||||
def get_protocol(self):
|
||||
return self._protocol
|
||||
|
||||
def set_protocol(self, protocol_id):
|
||||
self._protocol = protocol_id
|
||||
|
||||
def get_remote_address(self):
|
||||
return ("127.0.0.1", 1234)
|
||||
|
||||
await dcutr._handle_dcutr_stream(MalformedStream())
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_max_attempts_and_already_connected():
|
||||
"""Test max hole punch attempts and already-connected peer."""
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
peer_id = ID(b"peer")
|
||||
|
||||
# Simulate already having a direct connection
|
||||
dcutr._direct_connections.add(peer_id)
|
||||
result = await dcutr.initiate_hole_punch(peer_id)
|
||||
assert result is True
|
||||
|
||||
# Remove direct connection, simulate max attempts
|
||||
dcutr._direct_connections.clear()
|
||||
dcutr._hole_punch_attempts[peer_id] = MAX_HOLE_PUNCH_ATTEMPTS
|
||||
result = await dcutr.initiate_hole_punch(peer_id)
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_observed_addr_encoding_decoding():
|
||||
"""Test observed address encoding/decoding."""
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
# Simulate valid and invalid multiaddrs as bytes
|
||||
valid = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/1234").to_bytes(),
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/5678").to_bytes(),
|
||||
]
|
||||
invalid = [b"not-a-multiaddr", b""]
|
||||
decoded = dcutr._decode_observed_addrs(valid + invalid)
|
||||
assert len(decoded) == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dcutr_real_perform_hole_punch(monkeypatch):
|
||||
"""Test initiate_hole_punch with real _perform_hole_punch logic (mock network)."""
|
||||
mock_host = MagicMock()
|
||||
dcutr = DCUtRProtocol(mock_host)
|
||||
peer_id = ID(b"peer")
|
||||
|
||||
# Patch methods to simulate a successful punch
|
||||
monkeypatch.setattr(dcutr, "_have_direct_connection", AsyncMock(return_value=False))
|
||||
monkeypatch.setattr(
|
||||
dcutr,
|
||||
"_get_observed_addrs",
|
||||
AsyncMock(return_value=[b"/ip4/127.0.0.1/tcp/1234"]),
|
||||
)
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.read = AsyncMock(
|
||||
side_effect=[
|
||||
HolePunch(
|
||||
type=HolePunch.CONNECT, ObsAddrs=[b"/ip4/192.168.1.1/tcp/4321"]
|
||||
).SerializeToString(),
|
||||
HolePunch(type=HolePunch.SYNC).SerializeToString(),
|
||||
]
|
||||
)
|
||||
mock_stream.write = AsyncMock()
|
||||
mock_stream.close = AsyncMock()
|
||||
mock_stream.muxed_conn = MagicMock(peer_id=peer_id)
|
||||
mock_host.new_stream = AsyncMock(return_value=mock_stream)
|
||||
monkeypatch.setattr(dcutr, "_perform_hole_punch", AsyncMock(return_value=True))
|
||||
|
||||
result = await dcutr.initiate_hole_punch(peer_id)
|
||||
assert result is True
|
||||
297
tests/core/relay/test_nat.py
Normal file
297
tests/core/relay/test_nat.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""Tests for NAT traversal utilities."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.relay.circuit_v2.nat import (
|
||||
ReachabilityChecker,
|
||||
extract_ip_from_multiaddr,
|
||||
ip_to_int,
|
||||
is_ip_in_range,
|
||||
is_private_ip,
|
||||
)
|
||||
|
||||
|
||||
def test_ip_to_int_ipv4():
|
||||
"""Test converting IPv4 addresses to integers."""
|
||||
assert ip_to_int("192.168.1.1") == 3232235777
|
||||
assert ip_to_int("10.0.0.1") == 167772161
|
||||
assert ip_to_int("127.0.0.1") == 2130706433
|
||||
|
||||
|
||||
def test_ip_to_int_ipv6():
|
||||
"""Test converting IPv6 addresses to integers."""
|
||||
# Test with a simple IPv6 address
|
||||
ipv6_int = ip_to_int("::1")
|
||||
assert isinstance(ipv6_int, int)
|
||||
assert ipv6_int > 0
|
||||
|
||||
|
||||
def test_ip_to_int_invalid():
|
||||
"""Test handling of invalid IP addresses."""
|
||||
with pytest.raises(ValueError):
|
||||
ip_to_int("invalid-ip")
|
||||
|
||||
|
||||
def test_is_ip_in_range():
|
||||
"""Test IP range checking."""
|
||||
# Test within range
|
||||
assert is_ip_in_range("192.168.1.5", "192.168.1.1", "192.168.1.10") is True
|
||||
assert is_ip_in_range("10.0.0.5", "10.0.0.0", "10.0.0.255") is True
|
||||
|
||||
# Test outside range
|
||||
assert is_ip_in_range("192.168.2.5", "192.168.1.1", "192.168.1.10") is False
|
||||
assert is_ip_in_range("8.8.8.8", "10.0.0.0", "10.0.0.255") is False
|
||||
|
||||
|
||||
def test_is_ip_in_range_invalid():
|
||||
"""Test IP range checking with invalid inputs."""
|
||||
assert is_ip_in_range("invalid", "192.168.1.1", "192.168.1.10") is False
|
||||
assert is_ip_in_range("192.168.1.5", "invalid", "192.168.1.10") is False
|
||||
|
||||
|
||||
def test_is_private_ip():
|
||||
"""Test private IP detection."""
|
||||
# Private IPs
|
||||
assert is_private_ip("192.168.1.1") is True
|
||||
assert is_private_ip("10.0.0.1") is True
|
||||
assert is_private_ip("172.16.0.1") is True
|
||||
assert is_private_ip("127.0.0.1") is True # Loopback
|
||||
assert is_private_ip("169.254.1.1") is True # Link-local
|
||||
|
||||
# Public IPs
|
||||
assert is_private_ip("8.8.8.8") is False
|
||||
assert is_private_ip("1.1.1.1") is False
|
||||
assert is_private_ip("208.67.222.222") is False
|
||||
|
||||
|
||||
def test_extract_ip_from_multiaddr():
|
||||
"""Test IP extraction from multiaddrs."""
|
||||
# IPv4 addresses
|
||||
addr1 = Multiaddr("/ip4/192.168.1.1/tcp/1234")
|
||||
assert extract_ip_from_multiaddr(addr1) == "192.168.1.1"
|
||||
|
||||
addr2 = Multiaddr("/ip4/10.0.0.1/udp/5678")
|
||||
assert extract_ip_from_multiaddr(addr2) == "10.0.0.1"
|
||||
|
||||
# IPv6 addresses
|
||||
addr3 = Multiaddr("/ip6/::1/tcp/1234")
|
||||
assert extract_ip_from_multiaddr(addr3) == "::1"
|
||||
|
||||
addr4 = Multiaddr("/ip6/2001:db8::1/udp/5678")
|
||||
assert extract_ip_from_multiaddr(addr4) == "2001:db8::1"
|
||||
|
||||
# No IP address
|
||||
addr5 = Multiaddr("/dns4/example.com/tcp/1234")
|
||||
assert extract_ip_from_multiaddr(addr5) is None
|
||||
|
||||
# Complex multiaddr (without p2p to avoid base58 issues)
|
||||
addr6 = Multiaddr("/ip4/192.168.1.1/tcp/1234/udp/5678")
|
||||
assert extract_ip_from_multiaddr(addr6) == "192.168.1.1"
|
||||
|
||||
|
||||
def test_reachability_checker_init():
|
||||
"""Test ReachabilityChecker initialization."""
|
||||
mock_host = MagicMock()
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
|
||||
assert checker.host == mock_host
|
||||
assert checker._peer_reachability == {}
|
||||
assert checker._known_public_peers == set()
|
||||
|
||||
|
||||
def test_reachability_checker_is_addr_public():
|
||||
"""Test public address detection."""
|
||||
mock_host = MagicMock()
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
|
||||
# Public addresses
|
||||
public_addr1 = Multiaddr("/ip4/8.8.8.8/tcp/1234")
|
||||
assert checker.is_addr_public(public_addr1) is True
|
||||
|
||||
public_addr2 = Multiaddr("/ip4/1.1.1.1/udp/5678")
|
||||
assert checker.is_addr_public(public_addr2) is True
|
||||
|
||||
# Private addresses
|
||||
private_addr1 = Multiaddr("/ip4/192.168.1.1/tcp/1234")
|
||||
assert checker.is_addr_public(private_addr1) is False
|
||||
|
||||
private_addr2 = Multiaddr("/ip4/10.0.0.1/udp/5678")
|
||||
assert checker.is_addr_public(private_addr2) is False
|
||||
|
||||
private_addr3 = Multiaddr("/ip4/127.0.0.1/tcp/1234")
|
||||
assert checker.is_addr_public(private_addr3) is False
|
||||
|
||||
# No IP address
|
||||
dns_addr = Multiaddr("/dns4/example.com/tcp/1234")
|
||||
assert checker.is_addr_public(dns_addr) is False
|
||||
|
||||
|
||||
def test_reachability_checker_get_public_addrs():
|
||||
"""Test filtering for public addresses."""
|
||||
mock_host = MagicMock()
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
|
||||
addrs = [
|
||||
Multiaddr("/ip4/8.8.8.8/tcp/1234"), # Public
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
|
||||
Multiaddr("/ip4/1.1.1.1/udp/5678"), # Public
|
||||
Multiaddr("/ip4/10.0.0.1/tcp/1234"), # Private
|
||||
Multiaddr("/dns4/example.com/tcp/1234"), # DNS
|
||||
]
|
||||
|
||||
public_addrs = checker.get_public_addrs(addrs)
|
||||
assert len(public_addrs) == 2
|
||||
assert Multiaddr("/ip4/8.8.8.8/tcp/1234") in public_addrs
|
||||
assert Multiaddr("/ip4/1.1.1.1/udp/5678") in public_addrs
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_peer_reachability_connected_direct():
|
||||
"""Test peer reachability when directly connected."""
|
||||
mock_host = MagicMock()
|
||||
mock_network = MagicMock()
|
||||
mock_host.get_network.return_value = mock_network
|
||||
|
||||
peer_id = ID(b"test-peer-id")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.get_transport_addresses.return_value = [
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/1234") # Direct connection
|
||||
]
|
||||
|
||||
mock_network.connections = {peer_id: mock_conn}
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
result = await checker.check_peer_reachability(peer_id)
|
||||
|
||||
assert result is True
|
||||
assert checker._peer_reachability[peer_id] is True
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_peer_reachability_connected_relay():
|
||||
"""Test peer reachability when connected through relay."""
|
||||
mock_host = MagicMock()
|
||||
mock_network = MagicMock()
|
||||
mock_host.get_network.return_value = mock_network
|
||||
|
||||
peer_id = ID(b"test-peer-id")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.get_transport_addresses.return_value = [
|
||||
Multiaddr("/p2p-circuit/ip4/192.168.1.1/tcp/1234") # Relay connection
|
||||
]
|
||||
|
||||
mock_network.connections = {peer_id: mock_conn}
|
||||
|
||||
# Mock peerstore with public addresses
|
||||
mock_peerstore = MagicMock()
|
||||
mock_peerstore.addrs.return_value = [
|
||||
Multiaddr("/ip4/8.8.8.8/tcp/1234") # Public address
|
||||
]
|
||||
mock_host.get_peerstore.return_value = mock_peerstore
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
result = await checker.check_peer_reachability(peer_id)
|
||||
|
||||
assert result is True
|
||||
assert checker._peer_reachability[peer_id] is True
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_peer_reachability_not_connected():
|
||||
"""Test peer reachability when not connected."""
|
||||
mock_host = MagicMock()
|
||||
mock_network = MagicMock()
|
||||
mock_host.get_network.return_value = mock_network
|
||||
|
||||
peer_id = ID(b"test-peer-id")
|
||||
mock_network.connections = {} # No connections
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
result = await checker.check_peer_reachability(peer_id)
|
||||
|
||||
assert result is False
|
||||
# When not connected, the method doesn't add to cache
|
||||
assert peer_id not in checker._peer_reachability
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_peer_reachability_cached():
|
||||
"""Test that peer reachability results are cached."""
|
||||
mock_host = MagicMock()
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
|
||||
peer_id = ID(b"test-peer-id")
|
||||
checker._peer_reachability[peer_id] = True
|
||||
|
||||
result = await checker.check_peer_reachability(peer_id)
|
||||
assert result is True
|
||||
|
||||
# Should not call host methods when cached
|
||||
mock_host.get_network.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_self_reachability_with_public_addrs():
|
||||
"""Test self reachability when host has public addresses."""
|
||||
mock_host = MagicMock()
|
||||
mock_host.get_addrs.return_value = [
|
||||
Multiaddr("/ip4/8.8.8.8/tcp/1234"), # Public
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
|
||||
Multiaddr("/ip4/1.1.1.1/udp/5678"), # Public
|
||||
]
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
is_reachable, public_addrs = await checker.check_self_reachability()
|
||||
|
||||
assert is_reachable is True
|
||||
assert len(public_addrs) == 2
|
||||
assert Multiaddr("/ip4/8.8.8.8/tcp/1234") in public_addrs
|
||||
assert Multiaddr("/ip4/1.1.1.1/udp/5678") in public_addrs
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_self_reachability_no_public_addrs():
|
||||
"""Test self reachability when host has no public addresses."""
|
||||
mock_host = MagicMock()
|
||||
mock_host.get_addrs.return_value = [
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/1234"), # Private
|
||||
Multiaddr("/ip4/10.0.0.1/udp/5678"), # Private
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/1234"), # Loopback
|
||||
]
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
is_reachable, public_addrs = await checker.check_self_reachability()
|
||||
|
||||
assert is_reachable is False
|
||||
assert len(public_addrs) == 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_check_peer_reachability_multiple_connections():
|
||||
"""Test peer reachability with multiple connections."""
|
||||
mock_host = MagicMock()
|
||||
mock_network = MagicMock()
|
||||
mock_host.get_network.return_value = mock_network
|
||||
|
||||
peer_id = ID(b"test-peer-id")
|
||||
mock_conn1 = MagicMock()
|
||||
mock_conn1.get_transport_addresses.return_value = [
|
||||
Multiaddr("/p2p-circuit/ip4/192.168.1.1/tcp/1234") # Relay
|
||||
]
|
||||
|
||||
mock_conn2 = MagicMock()
|
||||
mock_conn2.get_transport_addresses.return_value = [
|
||||
Multiaddr("/ip4/192.168.1.1/tcp/1234") # Direct
|
||||
]
|
||||
|
||||
mock_network.connections = {peer_id: [mock_conn1, mock_conn2]}
|
||||
|
||||
checker = ReachabilityChecker(mock_host)
|
||||
result = await checker.check_peer_reachability(peer_id)
|
||||
|
||||
assert result is True
|
||||
assert checker._peer_reachability[peer_id] is True
|
||||
Reference in New Issue
Block a user