Mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2026-02-11 15:40:54 +00:00)

Compare commits: 28 commits, 333d56dc00 ... chore-02
Commits in this range (author and date columns were empty in the source view): 17ed7f82fb, 2a249b1792, 5ec1671608, 431a4807fb, f54a14b713, 37a4d96f90, 7bddb08808, d385cb45cf, 14a74fdbd1, 145727a9ba, 84c1a7031a, 6742dd38f7, fcb35084b3, 42c8937a8d, 64ccce17eb, 6a24b138dd, 8f5dd3bd11, 997094e5b7, dc205bff83, 3d1c36419c, 1037fbb0aa, c940dac1e6, 3b27b02a8b, 5a2fca32a0, 8d9b7f413d, a9f184be6a, 6931092eea, 163cc35cb0
@@ -24,8 +24,13 @@ async def main():
     noise_transport = NoiseTransport(
         # local_key_pair: The key pair used for libp2p identity and authentication
         libp2p_keypair=key_pair,
+        # noise_privkey: The private key used for Noise protocol encryption
         noise_privkey=key_pair.private_key,
-        # TODO: add early data
+        # early_data: Optional data to send during the handshake
+        # (None means no early data)
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -28,7 +28,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
         # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -31,7 +31,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
         # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -28,7 +28,9 @@ async def main():
         noise_privkey=key_pair.private_key,
         # early_data: Optional data to send during the handshake
        # (None means no early data)
-        # TODO: add early data
+        early_data=None,
+        # with_noise_pipes: Whether to use Noise pipes for additional security features
+        with_noise_pipes=False,
     )

     # Create a security options dictionary mapping protocol ID to transport
@@ -41,6 +41,7 @@ from libp2p.tools.async_service import (
 from libp2p.tools.utils import (
     info_from_p2p_addr,
 )
+from libp2p.utils.paths import get_script_dir, join_paths

 # Configure logging
 logging.basicConfig(
@@ -53,8 +54,8 @@ logger = logging.getLogger("kademlia-example")
 # Configure DHT module loggers to inherit from the parent logger
 # This ensures all kademlia-example.* loggers use the same configuration
 # Get the directory where this script is located
-SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
-SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
+SCRIPT_DIR = get_script_dir(__file__)
+SERVER_ADDR_LOG = join_paths(SCRIPT_DIR, "server_node_addr.txt")

 # Set the level for all child loggers
 for module in [
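The two helpers swapped in here come from the new libp2p/utils/paths.py shown at the end of this compare view. A sketch of the behavior they replace, expressed with pathlib; get_script_dir's exact signature is an assumption, while join_paths(*parts) -> Path matches the new module.

```python
from pathlib import Path

def get_script_dir(script_file: str) -> Path:
    # Directory containing the calling script, resolved to an absolute path
    # (equivalent to os.path.dirname(os.path.abspath(__file__)))
    return Path(script_file).resolve().parent

def join_paths(*parts) -> Path:
    # Cross-platform join built on pathlib, replacing os.path.join
    base, *rest = parts
    return Path(base).joinpath(*rest)
```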
@@ -2,15 +2,20 @@ import logging

 from multiaddr import Multiaddr
 from multiaddr.resolvers import DNSResolver
+import trio

 from libp2p.abc import ID, INetworkService, PeerInfo
 from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
 from libp2p.discovery.events.peerDiscovery import peerDiscovery
+from libp2p.network.exceptions import SwarmException
 from libp2p.peer.peerinfo import info_from_p2p_addr
+from libp2p.peer.peerstore import PERMANENT_ADDR_TTL

 logger = logging.getLogger("libp2p.discovery.bootstrap")
 resolver = DNSResolver()

+DEFAULT_CONNECTION_TIMEOUT = 10
+

 class BootstrapDiscovery:
     """
@@ -19,68 +24,147 @@ class BootstrapDiscovery:
     """

     def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
+        """
+        Initialize BootstrapDiscovery.
+
+        Args:
+            swarm: The network service (swarm) instance
+            bootstrap_addrs: List of bootstrap peer multiaddresses
+
+        """
         self.swarm = swarm
         self.peerstore = swarm.peerstore
         self.bootstrap_addrs = bootstrap_addrs or []
         self.discovered_peers: set[str] = set()
+        self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT

     async def start(self) -> None:
-        """Process bootstrap addresses and emit peer discovery events."""
-        logger.debug(
+        """Process bootstrap addresses and emit peer discovery events in parallel."""
+        logger.info(
             f"Starting bootstrap discovery with "
             f"{len(self.bootstrap_addrs)} bootstrap addresses"
         )

+        # Show all bootstrap addresses being processed
+        for i, addr in enumerate(self.bootstrap_addrs):
+            logger.debug(f"{i + 1}. {addr}")
+
         # Validate and filter bootstrap addresses
         self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
+        logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}")

-        for addr_str in self.bootstrap_addrs:
-            try:
-                await self._process_bootstrap_addr(addr_str)
-            except Exception as e:
-                logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
+        # Use Trio nursery for PARALLEL address processing
+        try:
+            async with trio.open_nursery() as nursery:
+                logger.debug(
+                    f"Starting {len(self.bootstrap_addrs)} parallel address "
+                    f"processing tasks"
+                )
+
+                # Start all bootstrap address processing tasks in parallel
+                for addr_str in self.bootstrap_addrs:
+                    logger.debug(f"Starting parallel task for: {addr_str}")
+                    nursery.start_soon(self._process_bootstrap_addr, addr_str)
+
+                # The nursery will wait for all address processing tasks to complete
+                logger.debug(
+                    "Nursery active - waiting for address processing tasks to complete"
+                )
+
+        except trio.Cancelled:
+            logger.debug("Bootstrap address processing cancelled - cleaning up tasks")
+            raise
+        except Exception as e:
+            logger.error(f"Bootstrap address processing failed: {e}")
+            raise
+
+        logger.info("Bootstrap discovery startup complete - all tasks finished")

     def stop(self) -> None:
         """Clean up bootstrap discovery resources."""
-        logger.debug("Stopping bootstrap discovery")
+        logger.info("Stopping bootstrap discovery and cleaning up tasks")

+        # Clear discovered peers
         self.discovered_peers.clear()

+        logger.debug("Bootstrap discovery cleanup completed")

     async def _process_bootstrap_addr(self, addr_str: str) -> None:
         """Convert string address to PeerInfo and add to peerstore."""
         try:
-            multiaddr = Multiaddr(addr_str)
+            try:
+                multiaddr = Multiaddr(addr_str)
+            except Exception as e:
+                logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
+                return
+
+            if self.is_dns_addr(multiaddr):
+                resolved_addrs = await resolver.resolve(multiaddr)
+                if resolved_addrs is None:
+                    logger.warning(f"DNS resolution returned None for: {addr_str}")
+                    return
+
+                peer_id_str = multiaddr.get_peer_id()
+                if peer_id_str is None:
+                    logger.warning(f"Missing peer ID in DNS address: {addr_str}")
+                    return
+                peer_id = ID.from_base58(peer_id_str)
+                addrs = [addr for addr in resolved_addrs]
+                if not addrs:
+                    logger.warning(f"No addresses resolved for DNS address: {addr_str}")
+                    return
+                peer_info = PeerInfo(peer_id, addrs)
+                await self.add_addr(peer_info)
+            else:
+                peer_info = info_from_p2p_addr(multiaddr)
+                await self.add_addr(peer_info)
         except Exception as e:
-            logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
-            return
-
-        if self.is_dns_addr(multiaddr):
-            resolved_addrs = await resolver.resolve(multiaddr)
-
-            peer_id_str = multiaddr.get_peer_id()
-
-            if peer_id_str is None:
-                logger.warning(f"Missing peer ID in DNS address: {addr_str}")
-                return
-            peer_id = ID.from_base58(peer_id_str)
-            addrs = [addr for addr in resolved_addrs]
-            if not addrs:
-                logger.warning(f"No addresses resolved for DNS address: {addr_str}")
-                return
-            peer_info = PeerInfo(peer_id, addrs)
-            self.add_addr(peer_info)
-        else:
-            self.add_addr(info_from_p2p_addr(multiaddr))
+            logger.warning(f"Failed to process bootstrap address {addr_str}: {e}")

     def is_dns_addr(self, addr: Multiaddr) -> bool:
         """Check if the address is a DNS address."""
         return any(protocol.name == "dnsaddr" for protocol in addr.protocols())

-    def add_addr(self, peer_info: PeerInfo) -> None:
-        """Add a peer to the peerstore and emit discovery event."""
+    async def add_addr(self, peer_info: PeerInfo) -> None:
+        """
+        Add a peer to the peerstore, emit discovery event,
+        and attempt connection in parallel.
+        """
+        logger.debug(
+            f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses"
+        )
+
         # Skip if it's our own peer
         if peer_info.peer_id == self.swarm.get_peer_id():
             logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
             return

-        # Always add addresses to peerstore (allows multiple addresses for same peer)
-        self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
+        # Filter addresses to only include IPv4+TCP (only supported protocol)
+        ipv4_tcp_addrs = []
+        filtered_out_addrs = []
+
+        for addr in peer_info.addrs:
+            if self._is_ipv4_tcp_addr(addr):
+                ipv4_tcp_addrs.append(addr)
+            else:
+                filtered_out_addrs.append(addr)
+
+        # Log filtering results
+        logger.debug(
+            f"Address filtering for {peer_info.peer_id}: "
+            f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered"
+        )
+
+        # Skip peer if no IPv4+TCP addresses available
+        if not ipv4_tcp_addrs:
+            logger.warning(
+                f"❌ No IPv4+TCP addresses for {peer_info.peer_id} - "
+                f"skipping connection attempts"
+            )
+            return
+
+        # Add only IPv4+TCP addresses to peerstore
+        self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL)

         # Only emit discovery event if this is the first time we see this peer
         peer_id_str = str(peer_info.peer_id)
@@ -89,6 +173,140 @@ class BootstrapDiscovery:
             self.discovered_peers.add(peer_id_str)
             # Emit peer discovery event
             peerDiscovery.emit_peer_discovered(peer_info)
-            logger.debug(f"Peer discovered: {peer_info.peer_id}")
+            logger.info(f"Peer discovered: {peer_info.peer_id}")
+
+            # Connect to peer (parallel across different bootstrap addresses)
+            logger.debug("Connecting to discovered peer...")
+            await self._connect_to_peer(peer_info.peer_id)
+
         else:
-            logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
+            logger.debug(
+                f"Additional addresses added for existing peer: {peer_info.peer_id}"
+            )
+            # Even for existing peers, try to connect if not already connected
+            if peer_info.peer_id not in self.swarm.connections:
+                logger.debug("Connecting to existing peer...")
+                await self._connect_to_peer(peer_info.peer_id)
+
+    async def _connect_to_peer(self, peer_id: ID) -> None:
+        """
+        Attempt to establish a connection to a peer with timeout.
+
+        Uses swarm.dial_peer to connect using addresses stored in peerstore.
+        Times out after self.connection_timeout seconds to prevent hanging.
+        """
+        logger.debug(f"Connection attempt for peer: {peer_id}")
+
+        # Pre-connection validation: Check if already connected
+        if peer_id in self.swarm.connections:
+            logger.debug(
+                f"Already connected to {peer_id} - skipping connection attempt"
+            )
+            return
+
+        # Check available addresses before attempting connection
+        available_addrs = self.peerstore.addrs(peer_id)
+        logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)")
+
+        if not available_addrs:
+            logger.error(f"❌ No addresses available for {peer_id} - cannot connect")
+            return
+
+        # Record start time for connection attempt monitoring
+        connection_start_time = trio.current_time()
+
+        try:
+            with trio.move_on_after(self.connection_timeout):
+                # Log connection attempt
+                logger.debug(
+                    f"Attempting connection to {peer_id} using "
+                    f"{len(available_addrs)} addresses"
+                )
+
+                # Use swarm.dial_peer to connect using stored addresses
+                await self.swarm.dial_peer(peer_id)
+
+                # Calculate connection time
+                connection_time = trio.current_time() - connection_start_time
+
+                # Post-connection validation: Verify connection was actually established
+                if peer_id in self.swarm.connections:
+                    logger.info(
+                        f"✅ Connected to {peer_id} (took {connection_time:.2f}s)"
+                    )
+
+                else:
+                    logger.warning(
+                        f"Dial succeeded but connection not found for {peer_id}"
+                    )
+        except trio.TooSlowError:
+            logger.warning(
+                f"❌ Connection to {peer_id} timed out after {self.connection_timeout}s"
+            )
+        except SwarmException as e:
+            # Calculate failed connection time
+            failed_connection_time = trio.current_time() - connection_start_time
+
+            # Enhanced error logging
+            error_msg = str(e)
+            if "no addresses established a successful connection" in error_msg:
+                logger.warning(
+                    f"❌ Failed to connect to {peer_id} after trying all "
+                    f"{len(available_addrs)} addresses "
+                    f"(took {failed_connection_time:.2f}s)"
+                )
+                # Log individual address failures if this is a MultiError
+                if (
+                    e.__cause__ is not None
+                    and hasattr(e.__cause__, "exceptions")
+                    and getattr(e.__cause__, "exceptions", None) is not None
+                ):
+                    exceptions_list = getattr(e.__cause__, "exceptions")
+                    logger.debug("📋 Individual address failure details:")
+                    for i, addr_exception in enumerate(exceptions_list, 1):
+                        logger.debug(f"Address {i}: {addr_exception}")
+                        # Also log the actual address that failed
+                        if i <= len(available_addrs):
+                            logger.debug(f"Failed address: {available_addrs[i - 1]}")
+                else:
+                    logger.warning("No detailed exception information available")
+            else:
+                logger.warning(
+                    f"❌ Failed to connect to {peer_id}: {e} "
+                    f"(took {failed_connection_time:.2f}s)"
+                )
+
+        except Exception as e:
+            # Handle unexpected errors that aren't swarm-specific
+            failed_connection_time = trio.current_time() - connection_start_time
+            logger.error(
+                f"❌ Unexpected error connecting to {peer_id}: "
+                f"{e} (took {failed_connection_time:.2f}s)"
+            )
+            # Don't re-raise to prevent killing the nursery and other parallel tasks
+
+    def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool:
+        """
+        Check if address is IPv4 with TCP protocol only.
+
+        Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols.
+        Only IPv4+TCP addresses are supported by the current transport.
+        """
+        try:
+            protocols = addr.protocols()
+
+            # Must have IPv4 protocol
+            has_ipv4 = any(p.name == "ip4" for p in protocols)
+            if not has_ipv4:
+                return False
+
+            # Must have TCP protocol
+            has_tcp = any(p.name == "tcp" for p in protocols)
+            if not has_tcp:
+                return False
+
+            return True
+
+        except Exception:
+            # If we can't parse the address, don't use it
+            return False
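Taken together, the new BootstrapDiscovery validates addresses, processes them in one trio nursery, and dials each peer with a per-connection timeout. A minimal driving sketch, assuming an already constructed swarm (INetworkService); the multiaddress and import of the class are illustrative only.

```python
import trio

async def run_bootstrap(swarm) -> None:
    # Hypothetical bootstrap address; any /dnsaddr or /ip4 + /tcp + /p2p address works.
    bootstrap_addrs = ["/ip4/127.0.0.1/tcp/8000/p2p/QmPeerIdGoesHere"]

    discovery = BootstrapDiscovery(swarm, bootstrap_addrs)  # class defined in the hunks above
    try:
        # Validates addresses, then processes and dials them in parallel inside
        # its own nursery; returns once all tasks have finished.
        await discovery.start()
    finally:
        discovery.stop()

# trio.run(run_bootstrap, swarm)  # called with a constructed swarm instance
```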
@@ -1,68 +0,0 @@
-from abc import ABC, abstractmethod
-
-from libp2p.abc import IRawConnection
-from libp2p.custom_types import TProtocol
-from libp2p.peer.id import ID
-
-from .pb import noise_pb2 as noise_pb
-
-
-class EarlyDataHandler(ABC):
-    """Interface for handling early data during Noise handshake"""
-
-    @abstractmethod
-    async def send(
-        self, conn: IRawConnection, peer_id: ID
-    ) -> noise_pb.NoiseExtensions | None:
-        """Called to generate early data to send during handshake"""
-        pass
-
-    @abstractmethod
-    async def received(
-        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
-    ) -> None:
-        """Called when early data is received during handshake"""
-        pass
-
-
-class TransportEarlyDataHandler(EarlyDataHandler):
-    """Default early data handler for muxer negotiation"""
-
-    def __init__(self, supported_muxers: list[TProtocol]):
-        self.supported_muxers = supported_muxers
-        self.received_muxers: list[TProtocol] = []
-
-    async def send(
-        self, conn: IRawConnection, peer_id: ID
-    ) -> noise_pb.NoiseExtensions | None:
-        """Send our supported muxers list"""
-        if not self.supported_muxers:
-            return None
-
-        extensions = noise_pb.NoiseExtensions()
-        # Convert TProtocol to string for serialization
-        extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
-        return extensions
-
-    async def received(
-        self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
-    ) -> None:
-        """Store received muxers list"""
-        if extensions and extensions.stream_muxers:
-            self.received_muxers = [
-                TProtocol(muxer) for muxer in extensions.stream_muxers
-            ]
-
-    def match_muxers(self, is_initiator: bool) -> TProtocol | None:
-        """Find first common muxer between local and remote"""
-        if is_initiator:
-            # Initiator: find first local muxer that remote supports
-            for local_muxer in self.supported_muxers:
-                if local_muxer in self.received_muxers:
-                    return local_muxer
-        else:
-            # Responder: find first remote muxer that we support
-            for remote_muxer in self.received_muxers:
-                if remote_muxer in self.supported_muxers:
-                    return remote_muxer
-        return None
@@ -30,9 +30,6 @@ from libp2p.security.secure_session import (
     SecureSession,
 )

-from .early_data import (
-    EarlyDataHandler,
-)
 from .exceptions import (
     HandshakeHasNotFinished,
     InvalidSignature,
@@ -48,7 +45,6 @@ from .messages import (
     make_handshake_payload_sig,
     verify_handshake_payload_sig,
 )
-from .pb import noise_pb2 as noise_pb


 class IPattern(ABC):
@@ -66,8 +62,7 @@ class BasePattern(IPattern):
     noise_static_key: PrivateKey
     local_peer: ID
     libp2p_privkey: PrivateKey
-    initiator_early_data_handler: EarlyDataHandler | None
-    responder_early_data_handler: EarlyDataHandler | None
+    early_data: bytes | None

     def create_noise_state(self) -> NoiseState:
         noise_state = NoiseState.from_name(self.protocol_name)
@@ -78,50 +73,11 @@ class BasePattern(IPattern):
             raise NoiseStateError("noise_protocol is not initialized")
         return noise_state

-    async def make_handshake_payload(
-        self, conn: IRawConnection, peer_id: ID, is_initiator: bool
-    ) -> NoiseHandshakePayload:
+    def make_handshake_payload(self) -> NoiseHandshakePayload:
         signature = make_handshake_payload_sig(
             self.libp2p_privkey, self.noise_static_key.get_public_key()
         )
-
-        # NEW: Get early data from appropriate handler
-        extensions = None
-        if is_initiator and self.initiator_early_data_handler:
-            extensions = await self.initiator_early_data_handler.send(conn, peer_id)
-        elif not is_initiator and self.responder_early_data_handler:
-            extensions = await self.responder_early_data_handler.send(conn, peer_id)
-
-        # NEW: Serialize extensions into early_data field
-        early_data = None
-        if extensions:
-            early_data = extensions.SerializeToString()
-
-        return NoiseHandshakePayload(
-            self.libp2p_privkey.get_public_key(),
-            signature,
-            early_data,  # ← This is the key addition
-        )
-
-    async def handle_received_payload(
-        self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
-    ) -> None:
-        """Process early data from received payload"""
-        if not payload.early_data:
-            return
-
-        # Deserialize the NoiseExtensions from early_data field
-        try:
-            extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
-        except Exception:
-            # Invalid extensions, ignore silently
-            return
-
-        # Pass to appropriate handler
-        if is_initiator and self.initiator_early_data_handler:
-            await self.initiator_early_data_handler.received(conn, extensions)
-        elif not is_initiator and self.responder_early_data_handler:
-            await self.responder_early_data_handler.received(conn, extensions)
+        return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)


 class PatternXX(BasePattern):
@@ -130,15 +86,13 @@ class PatternXX(BasePattern):
         local_peer: ID,
         libp2p_privkey: PrivateKey,
         noise_static_key: PrivateKey,
-        initiator_early_data_handler: EarlyDataHandler | None,
-        responder_early_data_handler: EarlyDataHandler | None,
+        early_data: bytes | None = None,
     ) -> None:
         self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
         self.local_peer = local_peer
         self.libp2p_privkey = libp2p_privkey
         self.noise_static_key = noise_static_key
-        self.initiator_early_data_handler = initiator_early_data_handler
-        self.responder_early_data_handler = responder_early_data_handler
+        self.early_data = early_data

     async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
         noise_state = self.create_noise_state()
@@ -152,23 +106,18 @@ class PatternXX(BasePattern):

         read_writer = NoiseHandshakeReadWriter(conn, noise_state)

-        # 1. Consume msg#1 (just empty bytes)
+        # Consume msg#1.
         await read_writer.read_msg()

-        # 2. Send msg#2 with our payload INCLUDING EARLY DATA
-        our_payload = await self.make_handshake_payload(
-            conn,
-            self.local_peer,  # We send our own peer ID in responder role
-            is_initiator=False,
-        )
+        # Send msg#2, which should include our handshake payload.
+        our_payload = self.make_handshake_payload()
         msg_2 = our_payload.serialize()
         await read_writer.write_msg(msg_2)

-        # 3. Receive msg#3
+        # Receive and consume msg#3.
         msg_3 = await read_writer.read_msg()
         peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)

-        # Extract remote pubkey from noise handshake state
         if handshake_state.rs is None:
             raise NoiseStateError(
                 "something is wrong in the underlying noise `handshake_state`: "
@@ -177,31 +126,14 @@ class PatternXX(BasePattern):
             )
         remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)

-        # 4. Verify signature (unchanged)
         if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
             raise InvalidSignature

-        # NEW: Process early data from msg#3 AFTER signature verification
-        await self.handle_received_payload(
-            conn, peer_handshake_payload, is_initiator=False
-        )
-
         remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)

         if not noise_state.handshake_finished:
             raise HandshakeHasNotFinished(
                 "handshake is done but it is not marked as finished in `noise_state`"
             )

-        # NEW: Get negotiated muxer for connection state
-        # negotiated_muxer = None
-        if self.responder_early_data_handler and hasattr(
-            self.responder_early_data_handler, "match_muxers"
-        ):
-            # negotiated_muxer =
-            # self.responder_early_data_handler.match_muxers(is_initiator=False)
-            pass
-
         transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
         return SecureSession(
             local_peer=self.local_peer,
@@ -210,8 +142,6 @@ class PatternXX(BasePattern):
             remote_permanent_pubkey=remote_pubkey,
             is_initiator=False,
             conn=transport_read_writer,
-            # NOTE: negotiated_muxer would need to be added to SecureSession constructor
-            # For now, store it in connection metadata or similar
         )

     async def handshake_outbound(
@@ -228,27 +158,24 @@ class PatternXX(BasePattern):
         if handshake_state is None:
             raise NoiseStateError("Handshake state is not initialized")

-        # 1. Send msg#1 (empty) - no early data possible in XX pattern
+        # Send msg#1, which is *not* encrypted.
         msg_1 = b""
         await read_writer.write_msg(msg_1)

-        # 2. Read msg#2 from responder
+        # Read msg#2 from the remote, which contains the public key of the peer.
         msg_2 = await read_writer.read_msg()
         peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)

-        # Extract remote pubkey from noise handshake state
         if handshake_state.rs is None:
             raise NoiseStateError(
                 "something is wrong in the underlying noise `handshake_state`: "
-                "we received and consumed msg#2, which should have included the "
+                "we received and consumed msg#3, which should have included the "
                 "remote static public key, but it is not present in the handshake_state"
             )
         remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)

-        # Verify signature BEFORE processing early data (security)
         if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
             raise InvalidSignature

         remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
         if remote_peer_id_from_pubkey != remote_peer:
             raise PeerIDMismatchesPubkey(
@@ -257,15 +184,8 @@ class PatternXX(BasePattern):
                 f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
             )

-        # NEW: Process early data from msg#2 AFTER verification
-        await self.handle_received_payload(
-            conn, peer_handshake_payload, is_initiator=True
-        )
-
-        # 3. Send msg#3 with our payload INCLUDING EARLY DATA
-        our_payload = await self.make_handshake_payload(
-            conn, remote_peer, is_initiator=True
-        )
+        # Send msg#3, which includes our encrypted payload and our noise static key.
+        our_payload = self.make_handshake_payload()
         msg_3 = our_payload.serialize()
         await read_writer.write_msg(msg_3)

@@ -273,16 +193,6 @@ class PatternXX(BasePattern):
             raise HandshakeHasNotFinished(
                 "handshake is done but it is not marked as finished in `noise_state`"
             )

-        # NEW: Get negotiated muxer
-        # negotiated_muxer = None
-        if self.initiator_early_data_handler and hasattr(
-            self.initiator_early_data_handler, "match_muxers"
-        ):
-            pass
-            # negotiated_muxer =
-            # self.initiator_early_data_handler.match_muxers(is_initiator=True)
-
         transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
         return SecureSession(
             local_peer=self.local_peer,
@@ -291,8 +201,6 @@ class PatternXX(BasePattern):
             remote_permanent_pubkey=remote_pubkey,
             is_initiator=True,
             conn=transport_read_writer,
-            # NOTE: negotiated_muxer would need to be added to SecureSession constructor
-            # For now, store it in connection metadata or similar
         )

     @staticmethod
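With the early-data plumbing removed, the XX pattern goes back to the plain three-message exchange implemented by handshake_inbound and handshake_outbound above. A schematic of that flow, written as comments only (not code from the repository):

```python
# XX handshake message flow after this change (initiator on the left).
#
#   initiator                                     responder
#   ---------                                     ---------
#   msg#1: b""  (not encrypted)            ->
#                                          <-     msg#2: responder handshake payload
#   msg#3: initiator handshake payload     ->
#
# Each payload is a serialized NoiseHandshakePayload(id_pubkey, signature)
# produced by make_handshake_payload(); the signature is checked with
# verify_handshake_payload_sig() before the remote peer ID is trusted.
```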
@@ -1,13 +1,8 @@
-syntax = "proto2";
+syntax = "proto3";
 package pb;

-message NoiseExtensions {
-  repeated bytes webtransport_certhashes = 1;
-  repeated string stream_muxers = 2;
-}
-
 message NoiseHandshakePayload {
-  optional bytes identity_key = 1;
-  optional bytes identity_sig = 2;
-  optional bytes data = 3;
+  bytes identity_key = 1;
+  bytes identity_sig = 2;
+  bytes data = 3;
 }
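Since the schema drops `optional` and moves to proto3, the regenerated message behaves like any proto3 scalar message: fields default to empty bytes and per-field presence checks go away (which is why HasField disappears from the .pyi below). A quick sketch using the field names from the hunk above; the byte values are placeholders.

```python
from libp2p.security.noise.pb import noise_pb2

payload = noise_pb2.NoiseHandshakePayload(
    identity_key=b"<public key bytes>",   # placeholder value
    identity_sig=b"<signature bytes>",    # placeholder value
    data=b"",                             # proto3 scalars default to empty, not "unset"
)
wire = payload.SerializeToString()
decoded = noise_pb2.NoiseHandshakePayload.FromString(wire)
assert decoded.identity_sig == payload.identity_sig
```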
@@ -13,15 +13,13 @@ _sym_db = _symbol_database.Default()



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')

 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
 if _descriptor._USE_C_DESCRIPTORS == False:

   DESCRIPTOR._options = None
-  _NOISEEXTENSIONS._serialized_start=44
-  _NOISEEXTENSIONS._serialized_end=117
-  _NOISEHANDSHAKEPAYLOAD._serialized_start=119
-  _NOISEHANDSHAKEPAYLOAD._serialized_end=200
+  _NOISEHANDSHAKEPAYLOAD._serialized_start=44
+  _NOISEHANDSHAKEPAYLOAD._serialized_end=125
 # @@protoc_insertion_point(module_scope)
@@ -4,34 +4,12 @@ isort:skip_file
 """

 import builtins
-import collections.abc
 import google.protobuf.descriptor
-import google.protobuf.internal.containers
 import google.protobuf.message
 import typing

 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

-@typing.final
-class NoiseExtensions(google.protobuf.message.Message):
-    DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-    WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
-    STREAM_MUXERS_FIELD_NUMBER: builtins.int
-    @property
-    def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
-    @property
-    def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
-    def __init__(
-        self,
-        *,
-        webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
-        stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
-    ) -> None: ...
-    def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
-
-global___NoiseExtensions = NoiseExtensions
-
 @typing.final
 class NoiseHandshakePayload(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -45,11 +23,10 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
     def __init__(
         self,
         *,
-        identity_key: builtins.bytes | None = ...,
-        identity_sig: builtins.bytes | None = ...,
-        data: builtins.bytes | None = ...,
+        identity_key: builtins.bytes = ...,
+        identity_sig: builtins.bytes = ...,
+        data: builtins.bytes = ...,
     ) -> None: ...
-    def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
     def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...

 global___NoiseHandshakePayload = NoiseHandshakePayload
@@ -14,7 +14,6 @@ from libp2p.peer.id import (
     ID,
 )

-from .early_data import EarlyDataHandler, TransportEarlyDataHandler
 from .patterns import (
     IPattern,
     PatternXX,
@@ -27,40 +26,35 @@ class Transport(ISecureTransport):
     libp2p_privkey: PrivateKey
     noise_privkey: PrivateKey
     local_peer: ID
-    supported_muxers: list[TProtocol]
-    initiator_early_data_handler: EarlyDataHandler | None
-    responder_early_data_handler: EarlyDataHandler | None
+    early_data: bytes | None
+    with_noise_pipes: bool

     def __init__(
         self,
         libp2p_keypair: KeyPair,
         noise_privkey: PrivateKey,
-        supported_muxers: list[TProtocol] | None = None,
-        initiator_handler: EarlyDataHandler | None = None,
-        responder_handler: EarlyDataHandler | None = None,
+        early_data: bytes | None = None,
+        with_noise_pipes: bool = False,
     ) -> None:
         self.libp2p_privkey = libp2p_keypair.private_key
         self.noise_privkey = noise_privkey
         self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
-        self.supported_muxers = supported_muxers or []
-
-        # Create default handlers for muxer negotiation if none provided
-        if initiator_handler is None and self.supported_muxers:
-            initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
-        if responder_handler is None and self.supported_muxers:
-            responder_handler = TransportEarlyDataHandler(self.supported_muxers)
-
-        self.initiator_early_data_handler = initiator_handler
-        self.responder_early_data_handler = responder_handler
+        self.early_data = early_data
+        self.with_noise_pipes = with_noise_pipes
+
+        if self.with_noise_pipes:
+            raise NotImplementedError

     def get_pattern(self) -> IPattern:
-        return PatternXX(
-            self.local_peer,
-            self.libp2p_privkey,
-            self.noise_privkey,
-            self.initiator_early_data_handler,
-            self.responder_early_data_handler,
-        )
+        if self.with_noise_pipes:
+            raise NotImplementedError
+        else:
+            return PatternXX(
+                self.local_peer,
+                self.libp2p_privkey,
+                self.noise_privkey,
+                self.early_data,
+            )

     async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
         pattern = self.get_pattern()
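The rewritten Transport keeps a with_noise_pipes flag but fails fast when it is requested, both in the constructor and in get_pattern. A small sketch of that behavior; the key-generation helper and import paths are assumptions, the rest follows the hunk above.

```python
from libp2p.crypto.secp256k1 import create_new_key_pair  # assumed helper
from libp2p.security.noise.transport import Transport as NoiseTransport

key_pair = create_new_key_pair()
try:
    NoiseTransport(
        libp2p_keypair=key_pair,
        noise_privkey=key_pair.private_key,
        with_noise_pipes=True,  # noise pipes are declared but not implemented
    )
except NotImplementedError:
    print("noise pipes are not implemented; construct with with_noise_pipes=False")
```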
@@ -1,5 +1,6 @@
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
+import time
 from types import (
     TracebackType,
 )
@@ -100,6 +101,12 @@ class ReadWriteLock:
             self.release_write()


+class MplexStreamTimeout(Exception):
+    """Raised when a stream operation exceeds its deadline."""
+
+    pass
+
+
 class MplexStream(IMuxedStream):
     """
     reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@@ -111,8 +118,8 @@ class MplexStream(IMuxedStream):
     # class of IMuxedConn. Ignoring this type assignment should not pose
     # any risk.
     muxed_conn: "Mplex"  # type: ignore[assignment]
-    read_deadline: int | None
-    write_deadline: int | None
+    read_deadline: float | None
+    write_deadline: float | None

     rw_lock: ReadWriteLock
     close_lock: trio.Lock
@@ -156,6 +163,30 @@ class MplexStream(IMuxedStream):
     def is_initiator(self) -> bool:
         return self.stream_id.is_initiator

+    def _check_read_deadline(self) -> None:
+        """Check if read deadline has expired and raise timeout if needed."""
+        if self.read_deadline is not None and time.time() > self.read_deadline:
+            raise MplexStreamTimeout("Read operation exceeded deadline")
+
+    def _check_write_deadline(self) -> None:
+        """Check if write deadline has expired and raise timeout if needed."""
+        if self.write_deadline is not None and time.time() > self.write_deadline:
+            raise MplexStreamTimeout("Write operation exceeded deadline")
+
+    def _get_read_timeout(self) -> float | None:
+        """Calculate remaining time until read deadline."""
+        if self.read_deadline is None:
+            return None
+        remaining = self.read_deadline - time.time()
+        return max(0.0, remaining) if remaining > 0 else 0
+
+    def _get_write_timeout(self) -> float | None:
+        """Calculate remaining time until write deadline."""
+        if self.write_deadline is None:
+            return None
+        remaining = self.write_deadline - time.time()
+        return max(0.0, remaining) if remaining > 0 else 0
+
     async def _read_until_eof(self) -> bytes:
         async for data in self.incoming_data_channel:
             self._buf.extend(data)
@@ -182,6 +213,9 @@ class MplexStream(IMuxedStream):
         :param n: number of bytes to read
         :return: bytes actually read
         """
+        # check deadline before starting
+        self._check_read_deadline()
+
         async with self.rw_lock.read_lock():
             if n is not None and n < 0:
                 raise ValueError(
@@ -192,8 +226,13 @@ class MplexStream(IMuxedStream):
                 raise MplexStreamReset
             if n is None:
                 return await self._read_until_eof()
+
+            # check deadline again before potentially blocking operation
+            self._check_read_deadline()
+
             if len(self._buf) == 0:
                 data: bytes
+                timeout = self._get_read_timeout()
                 # Peek whether there is data available. If yes, we just read until
                 # there is no data, then return.
                 try:
@@ -207,6 +246,20 @@ class MplexStream(IMuxedStream):
                     try:
                         data = await self.incoming_data_channel.receive()
                         self._buf.extend(data)
+                        if timeout is not None and timeout <= 0:
+                            raise MplexStreamTimeout(
+                                "Read deadline exceeded while waiting for data"
+                            )
+
+                        if timeout is not None:
+                            with trio.fail_after(timeout):
+                                data = await self.incoming_data_channel.receive()
+                        else:
+                            data = await self.incoming_data_channel.receive()
+
+                        self._buf.extend(data)
+                    except trio.TooSlowError:
+                        raise MplexStreamTimeout("Read operation timed out")
                     except trio.EndOfChannel:
                         if self.event_reset.is_set():
                             raise MplexStreamReset
@@ -226,15 +279,43 @@ class MplexStream(IMuxedStream):
             self._buf = self._buf[len(payload) :]
             return bytes(payload)

+    async def _read_until_eof_with_timeout(self) -> bytes:
+        """Read until EOF with timeout support."""
+        timeout = self._get_read_timeout()
+
+        try:
+            if timeout is not None:
+                with trio.fail_after(timeout):
+                    async for data in self.incoming_data_channel:
+                        self._buf.extend(data)
+            else:
+                async for data in self.incoming_data_channel:
+                    self._buf.extend(data)
+        except trio.TooSlowError:
+            raise MplexStreamTimeout("Read until EOF operation timed out")
+
+        payload = self._buf
+        self._buf = self._buf[len(payload) :]
+        return bytes(payload)
+
     async def write(self, data: bytes) -> None:
         """
         Write to stream.

         :return: number of bytes written
         """
+        # Check deadline before starting
+        self._check_write_deadline()
+
         async with self.rw_lock.write_lock():
             if self.event_local_closed.is_set():
                 raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")

+            # Check deadline again after acquiring lock
+            timeout = self._get_write_timeout()
+            if timeout is not None and timeout <= 0:
+                raise MplexStreamTimeout("Write deadline exceeded")
+
             flag = (
                 HeaderTags.MessageInitiator
                 if self.is_initiator
@@ -315,8 +396,9 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.read_deadline = ttl
-        self.write_deadline = ttl
+        deadline = time.time() + ttl
+        self.read_deadline = deadline
+        self.write_deadline = deadline
         return True

     def set_read_deadline(self, ttl: int) -> bool:
@@ -325,7 +407,7 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.read_deadline = ttl
+        self.read_deadline = time.time() + ttl
         return True

     def set_write_deadline(self, ttl: int) -> bool:
@@ -334,7 +416,7 @@ class MplexStream(IMuxedStream):

         :return: True if successful
         """
-        self.write_deadline = ttl
+        self.write_deadline = ttl + time.time()
         return True

     def get_remote_address(self) -> tuple[str, int] | None:
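After this change the deadline setters store absolute wall-clock times (time.time() + ttl), and read()/write() raise MplexStreamTimeout once that instant has passed. A usage sketch, assuming `stream` is an MplexStream obtained from an Mplex connection:

```python
async def read_with_deadline(stream, ttl: int = 5) -> bytes:
    # Deadline expires ttl seconds from now; it is absolute, not per-call.
    stream.set_read_deadline(ttl)
    try:
        return await stream.read(1024)
    except MplexStreamTimeout:
        # Retrying without resetting the deadline will keep failing;
        # call set_read_deadline() again to extend it.
        raise
```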
@ -1,7 +1,4 @@
import atexit
-from datetime import (
-    datetime,
-)
import logging
import logging.handlers
import os

@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue()
# Store the current listener to stop it on exit
_current_listener: logging.handlers.QueueListener | None = None

+# Store the handlers for proper cleanup
+_current_handlers: list[logging.Handler] = []
+
# Event to track when the listener is ready
_listener_ready = threading.Event()

@ -95,7 +95,7 @@ def setup_logging() -> None:
    - Child loggers inherit their parent's level unless explicitly set
    - The root libp2p logger controls the default level
    """
-    global _current_listener, _listener_ready
+    global _current_listener, _listener_ready, _current_handlers

    # Reset the event
    _listener_ready.clear()

@ -105,6 +105,12 @@ def setup_logging() -> None:
        _current_listener.stop()
        _current_listener = None

+    # Close and clear existing handlers
+    for handler in _current_handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.close()
+    _current_handlers.clear()
+
    # Get the log level from environment variable
    debug_str = os.environ.get("LIBP2P_DEBUG", "")

@ -148,13 +154,10 @@ def setup_logging() -> None:
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)
    else:
-        # Default log file with timestamp and unique identifier
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-        unique_id = os.urandom(4).hex()  # Add a unique identifier to prevent collisions
-        if os.name == "nt":  # Windows
-            log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log"
-        else:  # Unix-like
-            log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log"
+        # Use cross-platform temp file creation
+        from libp2p.utils.paths import create_temp_file
+
+        log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log"))

    # Print the log file path so users know where to find it
    print(f"Logging to: {log_file}", file=sys.stderr)

@ -195,6 +198,9 @@ def setup_logging() -> None:
        logger.setLevel(level)
        logger.propagate = False  # Prevent message duplication

+    # Store handlers globally for cleanup
+    _current_handlers.extend(handlers)
+
    # Start the listener AFTER configuring all loggers
    _current_listener = logging.handlers.QueueListener(
        log_queue, *handlers, respect_handler_level=True

@ -209,7 +215,13 @@ def setup_logging() -> None:
@atexit.register
def cleanup_logging() -> None:
    """Clean up logging resources on exit."""
-    global _current_listener
+    global _current_listener, _current_handlers
    if _current_listener is not None:
        _current_listener.stop()
        _current_listener = None

+    # Close all file handlers to ensure proper cleanup on Windows
+    for handler in _current_handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.close()
+    _current_handlers.clear()
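Illustrative usage sketch for the reworked logging setup (not part of the diff); only
LIBP2P_DEBUG, setup_logging() and the handler tracking are taken from the code above,
everything else is hypothetical:

    import os
    from libp2p.utils.logging import setup_logging

    os.environ["LIBP2P_DEBUG"] = "INFO"  # level for the root libp2p logger
    setup_logging()                      # prints "Logging to: <temp file>" to stderr
    # File handlers are kept in _current_handlers and closed by cleanup_logging()
    # at interpreter exit (registered via atexit).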
libp2p/utils/paths.py (new file)
@ -0,0 +1,267 @@
"""
Cross-platform path utilities for py-libp2p.

This module provides standardized path operations to ensure consistent
behavior across Windows, macOS, and Linux platforms.
"""

import os
from pathlib import Path
import sys
import tempfile
from typing import Union

PathLike = Union[str, Path]


def get_temp_dir() -> Path:
    """
    Get cross-platform temporary directory.

    Returns:
        Path: Platform-specific temporary directory path

    """
    return Path(tempfile.gettempdir())


def get_project_root() -> Path:
    """
    Get the project root directory.

    Returns:
        Path: Path to the py-libp2p project root

    """
    # Navigate from libp2p/utils/paths.py to project root
    return Path(__file__).parent.parent.parent


def join_paths(*parts: PathLike) -> Path:
    """
    Cross-platform path joining.

    Args:
        *parts: Path components to join

    Returns:
        Path: Joined path using platform-appropriate separator

    """
    return Path(*parts)


def ensure_dir_exists(path: PathLike) -> Path:
    """
    Ensure directory exists, create if needed.

    Args:
        path: Directory path to ensure exists

    Returns:
        Path: Path object for the directory

    """
    path_obj = Path(path)
    path_obj.mkdir(parents=True, exist_ok=True)
    return path_obj


def get_config_dir() -> Path:
    """
    Get user config directory (cross-platform).

    Returns:
        Path: Platform-specific config directory

    """
    if os.name == "nt":  # Windows
        appdata = os.environ.get("APPDATA", "")
        if appdata:
            return Path(appdata) / "py-libp2p"
        else:
            # Fallback to user home directory
            return Path.home() / "AppData" / "Roaming" / "py-libp2p"
    else:  # Unix-like (Linux, macOS)
        return Path.home() / ".config" / "py-libp2p"


def get_script_dir(script_path: PathLike | None = None) -> Path:
    """
    Get the directory containing a script file.

    Args:
        script_path: Path to the script file. If None, uses __file__

    Returns:
        Path: Directory containing the script

    Raises:
        RuntimeError: If script path cannot be determined

    """
    if script_path is None:
        # This will be the directory of the calling script
        import inspect

        frame = inspect.currentframe()
        if frame and frame.f_back:
            script_path = frame.f_back.f_globals.get("__file__")
        else:
            raise RuntimeError("Could not determine script path")

    if script_path is None:
        raise RuntimeError("Script path is None")

    return Path(script_path).parent.absolute()


def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path:
    """
    Create a temporary file with a unique name.

    Args:
        prefix: File name prefix
        suffix: File name suffix

    Returns:
        Path: Path to the created temporary file

    """
    temp_dir = get_temp_dir()
    # Create a unique filename using timestamp and random bytes
    import secrets
    import time

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    microseconds = f"{time.time() % 1:.6f}"[2:]  # Get microseconds as string
    unique_id = secrets.token_hex(4)
    filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}"

    temp_file = temp_dir / filename
    # Create the file by touching it
    temp_file.touch()
    return temp_file


def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path:
    """
    Resolve a relative path from a base path.

    Args:
        base_path: Base directory path
        relative_path: Relative path to resolve

    Returns:
        Path: Resolved absolute path

    """
    base = Path(base_path).resolve()
    relative = Path(relative_path)

    if relative.is_absolute():
        return relative
    else:
        return (base / relative).resolve()


def normalize_path(path: PathLike) -> Path:
    """
    Normalize a path, resolving any symbolic links and relative components.

    Args:
        path: Path to normalize

    Returns:
        Path: Normalized absolute path

    """
    return Path(path).resolve()


def get_venv_path() -> Path | None:
    """
    Get virtual environment path if active.

    Returns:
        Path: Virtual environment path if active, None otherwise

    """
    venv_path = os.environ.get("VIRTUAL_ENV")
    if venv_path:
        return Path(venv_path)
    return None


def get_python_executable() -> Path:
    """
    Get current Python executable path.

    Returns:
        Path: Path to the current Python executable

    """
    return Path(sys.executable)


def find_executable(name: str) -> Path | None:
    """
    Find executable in system PATH.

    Args:
        name: Name of the executable to find

    Returns:
        Path: Path to executable if found, None otherwise

    """
    # Check if name already contains path
    if os.path.dirname(name):
        path = Path(name)
        if path.exists() and os.access(path, os.X_OK):
            return path
        return None

    # Search in PATH
    for path_dir in os.environ.get("PATH", "").split(os.pathsep):
        if not path_dir:
            continue
        path = Path(path_dir) / name
        if path.exists() and os.access(path, os.X_OK):
            return path

    return None


def get_script_binary_path() -> Path:
    """
    Get path to script's binary directory.

    Returns:
        Path: Directory containing the script's binary

    """
    return get_python_executable().parent


def get_binary_path(binary_name: str) -> Path | None:
    """
    Find binary in PATH or virtual environment.

    Args:
        binary_name: Name of the binary to find

    Returns:
        Path: Path to binary if found, None otherwise

    """
    # First check in virtual environment if active
    venv_path = get_venv_path()
    if venv_path:
        venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts"
        binary_path = venv_bin / binary_name
        if binary_path.exists() and os.access(binary_path, os.X_OK):
            return binary_path

    # Fall back to system PATH
    return find_executable(binary_name)
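Illustrative usage sketch for the new path helpers (not part of the diff; the file and
directory names shown are hypothetical):

    from libp2p.utils.paths import (
        create_temp_file,
        get_config_dir,
        get_temp_dir,
        join_paths,
    )

    log_file = create_temp_file(prefix="py-libp2p_", suffix=".log")  # unique temp log
    cfg_path = get_config_dir() / "bootstrap.json"  # e.g. ~/.config/py-libp2p/... on Unix
    cache_dir = join_paths(get_temp_dir(), "py-libp2p", "cache")  # platform-safe joining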
newsfragments/849.feature.rst (new file)
@ -0,0 +1 @@
Add automatic peer dialing in bootstrap module using trio.Nursery.

newsfragments/886.bugfix.rst (new file)
@ -0,0 +1,2 @@
Fixed cross-platform path handling by replacing hardcoded OS-specific
paths with standardized utilities in core modules and examples.
scripts/audit_paths.py (new file)
@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""
Audit script to identify path handling issues in the py-libp2p codebase.

This script scans for patterns that should be migrated to use the new
cross-platform path utilities.
"""

import argparse
from pathlib import Path
import re
from typing import Any


def scan_for_path_issues(directory: Path) -> dict[str, list[dict[str, Any]]]:
    """
    Scan for path handling issues in the codebase.

    Args:
        directory: Root directory to scan

    Returns:
        Dictionary mapping issue types to lists of found issues

    """
    issues = {
        "hard_coded_slash": [],
        "os_path_join": [],
        "temp_hardcode": [],
        "os_path_dirname": [],
        "os_path_abspath": [],
        "direct_path_concat": [],
    }

    # Patterns to search for
    patterns = {
        "hard_coded_slash": r'["\'][^"\']*\/[^"\']*["\']',
        "os_path_join": r"os\.path\.join\(",
        "temp_hardcode": r'["\']\/tmp\/|["\']C:\\\\',
        "os_path_dirname": r"os\.path\.dirname\(",
        "os_path_abspath": r"os\.path\.abspath\(",
        "direct_path_concat": r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']',
    }

    # Files to exclude
    exclude_patterns = [
        r"__pycache__",
        r"\.git",
        r"\.pytest_cache",
        r"\.mypy_cache",
        r"\.ruff_cache",
        r"env/",
        r"venv/",
        r"\.venv/",
    ]

    for py_file in directory.rglob("*.py"):
        # Skip excluded files
        if any(re.search(pattern, str(py_file)) for pattern in exclude_patterns):
            continue

        try:
            content = py_file.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            print(f"Warning: Could not read {py_file} (encoding issue)")
            continue

        for issue_type, pattern in patterns.items():
            matches = re.finditer(pattern, content, re.MULTILINE)
            for match in matches:
                line_num = content[: match.start()].count("\n") + 1
                line_content = content.split("\n")[line_num - 1].strip()

                issues[issue_type].append(
                    {
                        "file": py_file,
                        "line": line_num,
                        "content": match.group(),
                        "full_line": line_content,
                        "relative_path": py_file.relative_to(directory),
                    }
                )

    return issues


def generate_migration_suggestions(issues: dict[str, list[dict[str, Any]]]) -> str:
    """
    Generate migration suggestions for found issues.

    Args:
        issues: Dictionary of found issues

    Returns:
        Formatted string with migration suggestions

    """
    suggestions = []

    for issue_type, issue_list in issues.items():
        if not issue_list:
            continue

        suggestions.append(f"\n## {issue_type.replace('_', ' ').title()}")
        suggestions.append(f"Found {len(issue_list)} instances:")

        for issue in issue_list[:10]:  # Show first 10 examples
            suggestions.append(f"\n### {issue['relative_path']}:{issue['line']}")
            suggestions.append("```python")
            suggestions.append("# Current code:")
            suggestions.append(f"{issue['full_line']}")
            suggestions.append("```")

            # Add migration suggestion based on issue type
            if issue_type == "os_path_join":
                suggestions.append("```python")
                suggestions.append("# Suggested fix:")
                suggestions.append("from libp2p.utils.paths import join_paths")
                suggestions.append(
                    "# Replace os.path.join(a, b, c) with join_paths(a, b, c)"
                )
                suggestions.append("```")
            elif issue_type == "temp_hardcode":
                suggestions.append("```python")
                suggestions.append("# Suggested fix:")
                suggestions.append(
                    "from libp2p.utils.paths import get_temp_dir, create_temp_file"
                )
                temp_fix_msg = (
                    "# Replace hard-coded temp paths with get_temp_dir() or "
                    "create_temp_file()"
                )
                suggestions.append(temp_fix_msg)
                suggestions.append("```")
            elif issue_type == "os_path_dirname":
                suggestions.append("```python")
                suggestions.append("# Suggested fix:")
                suggestions.append("from libp2p.utils.paths import get_script_dir")
                script_dir_fix_msg = (
                    "# Replace os.path.dirname(os.path.abspath(__file__)) with "
                    "get_script_dir(__file__)"
                )
                suggestions.append(script_dir_fix_msg)
                suggestions.append("```")

        if len(issue_list) > 10:
            suggestions.append(f"\n... and {len(issue_list) - 10} more instances")

    return "\n".join(suggestions)


def generate_summary_report(issues: dict[str, list[dict[str, Any]]]) -> str:
    """
    Generate a summary report of all found issues.

    Args:
        issues: Dictionary of found issues

    Returns:
        Formatted summary report

    """
    total_issues = sum(len(issue_list) for issue_list in issues.values())

    report = [
        "# Cross-Platform Path Handling Audit Report",
        "",
        "## Summary",
        f"Total issues found: {total_issues}",
        "",
        "## Issue Breakdown:",
    ]

    for issue_type, issue_list in issues.items():
        if issue_list:
            issue_title = issue_type.replace("_", " ").title()
            instances_count = len(issue_list)
            report.append(f"- **{issue_title}**: {instances_count} instances")

    report.append("")
    report.append("## Priority Matrix:")
    report.append("")
    report.append("| Priority | Issue Type | Risk Level | Impact |")
    report.append("|----------|------------|------------|---------|")

    priority_map = {
        "temp_hardcode": (
            "🔴 P0",
            "HIGH",
            "Core functionality fails on different platforms",
        ),
        "os_path_join": ("🟡 P1", "MEDIUM", "Examples and utilities may break"),
        "os_path_dirname": ("🟡 P1", "MEDIUM", "Script location detection issues"),
        "hard_coded_slash": ("🟢 P2", "LOW", "Future-proofing and consistency"),
        "os_path_abspath": ("🟢 P2", "LOW", "Path resolution consistency"),
        "direct_path_concat": ("🟢 P2", "LOW", "String concatenation issues"),
    }

    for issue_type, issue_list in issues.items():
        if issue_list:
            priority, risk, impact = priority_map.get(
                issue_type, ("🟢 P2", "LOW", "General improvement")
            )
            issue_title = issue_type.replace("_", " ").title()
            report.append(f"| {priority} | {issue_title} | {risk} | {impact} |")

    return "\n".join(report)


def main():
    """Main function to run the audit."""
    parser = argparse.ArgumentParser(
        description="Audit py-libp2p codebase for path handling issues"
    )
    parser.add_argument(
        "--directory",
        default=".",
        help="Directory to scan (default: current directory)",
    )
    parser.add_argument("--output", help="Output file for detailed report")
    parser.add_argument(
        "--summary-only", action="store_true", help="Only show summary report"
    )

    args = parser.parse_args()

    directory = Path(args.directory)
    if not directory.exists():
        print(f"Error: Directory {directory} does not exist")
        return 1

    print("🔍 Scanning for path handling issues...")
    issues = scan_for_path_issues(directory)

    # Generate and display summary
    summary = generate_summary_report(issues)
    print(summary)

    if not args.summary_only:
        # Generate detailed suggestions
        suggestions = generate_migration_suggestions(issues)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(summary)
                f.write(suggestions)
            print(f"\n📄 Detailed report saved to {args.output}")
        else:
            print(suggestions)

    return 0


if __name__ == "__main__":
    exit(main())
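Illustrative invocation of the audit script (not part of the diff; the report file
name is hypothetical, the flags come from the argparse definitions above):

    # Summary table only
    python scripts/audit_paths.py --directory . --summary-only

    # Full report with per-file migration suggestions written to a file
    python scripts/audit_paths.py --directory . --output path_audit_report.md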
@ -1,13 +0,0 @@
-from libp2p.security.noise.pb import noise_pb2 as noise_pb
-
-
-def test_noise_extensions_serialization():
-    # Test NoiseExtensions
-    ext = noise_pb.NoiseExtensions()
-    ext.stream_muxers.append("/mplex/6.7.0")
-    ext.stream_muxers.append("/yamux/1.0.0")
-
-    # Serialize and deserialize
-    data = ext.SerializeToString()
-    ext2 = noise_pb.NoiseExtensions.FromString(data)
-    assert list(ext2.stream_muxers) == ["/mplex/6.7.0", "/yamux/1.0.0"]
@ -173,7 +173,8 @@ def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport:
    return NoiseTransport(
        libp2p_keypair=key_pair,
        noise_privkey=noise_static_key_factory(),
-        # TODO: add early data
+        early_data=None,
+        with_noise_pipes=False,
    )
@ -15,6 +15,7 @@ import pytest
import trio

from libp2p.utils.logging import (
+    _current_handlers,
    _current_listener,
    _listener_ready,
    log_queue,

@ -24,13 +25,19 @@ from libp2p.utils.logging import (

def _reset_logging():
    """Reset all logging state."""
-    global _current_listener, _listener_ready
+    global _current_listener, _listener_ready, _current_handlers

    # Stop existing listener if any
    if _current_listener is not None:
        _current_listener.stop()
        _current_listener = None

+    # Close all file handlers to ensure proper cleanup on Windows
+    for handler in _current_handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.close()
+    _current_handlers.clear()
+
    # Reset the event
    _listener_ready = threading.Event()

@ -174,6 +181,15 @@ async def test_custom_log_file(clean_env):
    if _current_listener is not None:
        _current_listener.stop()

+    # Give a moment for the listener to fully stop
+    await trio.sleep(0.05)
+
+    # Close all file handlers to release the file
+    for handler in _current_handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.flush()  # Ensure all writes are flushed
+            handler.close()
+
    # Check if the file exists and contains our message
    assert log_file.exists()
    content = log_file.read_text()

@ -185,16 +201,15 @@ async def test_default_log_file(clean_env):
    """Test logging to the default file path."""
    os.environ["LIBP2P_DEBUG"] = "INFO"

-    with patch("libp2p.utils.logging.datetime") as mock_datetime:
-        # Mock the timestamp to have a predictable filename
-        mock_datetime.now.return_value.strftime.return_value = "20240101_120000"
+    with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp:
+        # Mock the temp file creation to return a predictable path
+        mock_temp_file = (
+            Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log"
+        )
+        mock_create_temp.return_value = mock_temp_file

        # Remove the log file if it exists
-        if os.name == "nt":  # Windows
-            log_file = Path("C:/Windows/Temp/20240101_120000_py-libp2p.log")
-        else:  # Unix-like
-            log_file = Path("/tmp/20240101_120000_py-libp2p.log")
-        log_file.unlink(missing_ok=True)
+        mock_temp_file.unlink(missing_ok=True)

        setup_logging()

@ -211,9 +226,18 @@ async def test_default_log_file(clean_env):
        if _current_listener is not None:
            _current_listener.stop()

-        # Check the default log file
-        if log_file.exists():  # Only check content if we have write permission
-            content = log_file.read_text()
+        # Give a moment for the listener to fully stop
+        await trio.sleep(0.05)
+
+        # Close all file handlers to release the file
+        for handler in _current_handlers:
+            if isinstance(handler, logging.FileHandler):
+                handler.flush()  # Ensure all writes are flushed
+                handler.close()
+
+        # Check the mocked temp file
+        if mock_temp_file.exists():
+            content = mock_temp_file.read_text()
            assert "Test message" in content
tests/utils/test_paths.py (new file)
@ -0,0 +1,290 @@
"""
Tests for cross-platform path utilities.
"""

import os
from pathlib import Path
import tempfile

import pytest

from libp2p.utils.paths import (
    create_temp_file,
    ensure_dir_exists,
    find_executable,
    get_binary_path,
    get_config_dir,
    get_project_root,
    get_python_executable,
    get_script_binary_path,
    get_script_dir,
    get_temp_dir,
    get_venv_path,
    join_paths,
    normalize_path,
    resolve_relative_path,
)


class TestPathUtilities:
    """Test cross-platform path utilities."""

    def test_get_temp_dir(self):
        """Test that temp directory is accessible and exists."""
        temp_dir = get_temp_dir()
        assert isinstance(temp_dir, Path)
        assert temp_dir.exists()
        assert temp_dir.is_dir()
        # Should match system temp directory
        assert temp_dir == Path(tempfile.gettempdir())

    def test_get_project_root(self):
        """Test that project root is correctly determined."""
        project_root = get_project_root()
        assert isinstance(project_root, Path)
        assert project_root.exists()
        # Should contain pyproject.toml
        assert (project_root / "pyproject.toml").exists()
        # Should contain libp2p directory
        assert (project_root / "libp2p").exists()

    def test_join_paths(self):
        """Test cross-platform path joining."""
        # Test with strings
        result = join_paths("a", "b", "c")
        expected = Path("a") / "b" / "c"
        assert result == expected

        # Test with mixed types
        result = join_paths("a", Path("b"), "c")
        expected = Path("a") / "b" / "c"
        assert result == expected

        # Test with absolute path
        result = join_paths("/absolute", "path")
        expected = Path("/absolute") / "path"
        assert result == expected

    def test_ensure_dir_exists(self, tmp_path):
        """Test directory creation and existence checking."""
        # Test creating new directory
        new_dir = tmp_path / "new_dir"
        result = ensure_dir_exists(new_dir)
        assert result == new_dir
        assert new_dir.exists()
        assert new_dir.is_dir()

        # Test creating nested directory
        nested_dir = tmp_path / "parent" / "child" / "grandchild"
        result = ensure_dir_exists(nested_dir)
        assert result == nested_dir
        assert nested_dir.exists()
        assert nested_dir.is_dir()

        # Test with existing directory
        result = ensure_dir_exists(new_dir)
        assert result == new_dir
        assert new_dir.exists()

    def test_get_config_dir(self):
        """Test platform-specific config directory."""
        config_dir = get_config_dir()
        assert isinstance(config_dir, Path)

        if os.name == "nt":  # Windows
            # Should be in AppData/Roaming or user home
            assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir)
        else:  # Unix-like
            # Should be in ~/.config
            assert ".config" in str(config_dir)
            assert "py-libp2p" in str(config_dir)

    def test_get_script_dir(self):
        """Test script directory detection."""
        # Test with current file
        script_dir = get_script_dir(__file__)
        assert isinstance(script_dir, Path)
        assert script_dir.exists()
        assert script_dir.is_dir()
        # Should contain this test file
        assert (script_dir / "test_paths.py").exists()

    def test_create_temp_file(self):
        """Test temporary file creation."""
        temp_file = create_temp_file()
        assert isinstance(temp_file, Path)
        assert temp_file.parent == get_temp_dir()
        assert temp_file.name.startswith("py-libp2p_")
        assert temp_file.name.endswith(".log")

        # Test with custom prefix and suffix
        temp_file = create_temp_file(prefix="test_", suffix=".txt")
        assert temp_file.name.startswith("test_")
        assert temp_file.name.endswith(".txt")

    def test_resolve_relative_path(self, tmp_path):
        """Test relative path resolution."""
        base_path = tmp_path / "base"
        base_path.mkdir()

        # Test relative path
        relative_path = "subdir/file.txt"
        result = resolve_relative_path(base_path, relative_path)
        expected = (base_path / "subdir" / "file.txt").resolve()
        assert result == expected

        # Test absolute path (platform-agnostic)
        if os.name == "nt":  # Windows
            absolute_path = "C:\\absolute\\path"
        else:  # Unix-like
            absolute_path = "/absolute/path"
        result = resolve_relative_path(base_path, absolute_path)
        assert result == Path(absolute_path)

    def test_normalize_path(self, tmp_path):
        """Test path normalization."""
        # Test with relative path
        relative_path = tmp_path / ".." / "normalize_test"
        result = normalize_path(relative_path)
        assert result.is_absolute()
        assert "normalize_test" in str(result)

        # Test with absolute path
        absolute_path = tmp_path / "test_file"
        result = normalize_path(absolute_path)
        assert result.is_absolute()
        assert result == absolute_path.resolve()

    def test_get_venv_path(self, monkeypatch):
        """Test virtual environment path detection."""
        # Test when no virtual environment is active
        # Temporarily clear VIRTUAL_ENV to test the "no venv" case
        monkeypatch.delenv("VIRTUAL_ENV", raising=False)
        result = get_venv_path()
        assert result is None

        # Test when virtual environment is active
        test_venv_path = "/path/to/venv"
        monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
        result = get_venv_path()
        assert result == Path(test_venv_path)

    def test_get_python_executable(self):
        """Test Python executable path detection."""
        result = get_python_executable()
        assert isinstance(result, Path)
        assert result.exists()
        assert result.name.startswith("python")

    def test_find_executable(self):
        """Test executable finding in PATH."""
        # Test with non-existent executable
        result = find_executable("nonexistent_executable")
        assert result is None

        # Test with existing executable (python should be available)
        result = find_executable("python")
        if result:
            assert isinstance(result, Path)
            assert result.exists()

    def test_get_script_binary_path(self):
        """Test script binary path detection."""
        result = get_script_binary_path()
        assert isinstance(result, Path)
        assert result.exists()
        assert result.is_dir()

    def test_get_binary_path(self, monkeypatch):
        """Test binary path resolution with virtual environment."""
        # Test when no virtual environment is active
        result = get_binary_path("python")
        if result:
            assert isinstance(result, Path)
            assert result.exists()

        # Test when virtual environment is active
        test_venv_path = "/path/to/venv"
        monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
        # This test is more complex as it depends on the actual venv structure
        # We'll just verify the function doesn't crash
        result = get_binary_path("python")
        # Result can be None if binary not found in venv
        if result:
            assert isinstance(result, Path)


class TestCrossPlatformCompatibility:
    """Test cross-platform compatibility."""

    def test_config_dir_platform_specific_windows(self, monkeypatch):
        """Test config directory respects Windows conventions."""
        import platform

        # Only run this test on Windows systems
        if platform.system() != "Windows":
            pytest.skip("This test only runs on Windows systems")

        monkeypatch.setattr("os.name", "nt")
        monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
        config_dir = get_config_dir()
        assert "AppData" in str(config_dir)
        assert "py-libp2p" in str(config_dir)

    def test_path_separators_consistent(self):
        """Test that path separators are handled consistently."""
        # Test that join_paths uses platform-appropriate separators
        result = join_paths("dir1", "dir2", "file.txt")
        expected = Path("dir1") / "dir2" / "file.txt"
        assert result == expected

        # Test that the result uses correct separators for the platform
        if os.name == "nt":  # Windows
            assert "\\" in str(result) or "/" in str(result)
        else:  # Unix-like
            assert "/" in str(result)

    def test_temp_file_uniqueness(self):
        """Test that temporary files have unique names."""
        files = set()
        for _ in range(10):
            temp_file = create_temp_file()
            assert temp_file not in files
            files.add(temp_file)


class TestBackwardCompatibility:
    """Test backward compatibility with existing code patterns."""

    def test_path_operations_equivalent(self):
        """Test that new path operations are equivalent to old os.path operations."""
        # Test join_paths vs os.path.join
        parts = ["a", "b", "c"]
        new_result = join_paths(*parts)
        old_result = Path(os.path.join(*parts))
        assert new_result == old_result

        # Test get_script_dir vs os.path.dirname(os.path.abspath(__file__))
        new_script_dir = get_script_dir(__file__)
        old_script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
        assert new_script_dir == old_script_dir

    def test_existing_functionality_preserved(self):
        """Ensure no existing functionality is broken."""
        # Test that all functions return Path objects
        assert isinstance(get_temp_dir(), Path)
        assert isinstance(get_project_root(), Path)
        assert isinstance(join_paths("a", "b"), Path)
        assert isinstance(ensure_dir_exists(tempfile.gettempdir()), Path)
        assert isinstance(get_config_dir(), Path)
        assert isinstance(get_script_dir(__file__), Path)
        assert isinstance(create_temp_file(), Path)
        assert isinstance(resolve_relative_path(".", "test"), Path)
        assert isinstance(normalize_path("."), Path)
        assert isinstance(get_python_executable(), Path)
        assert isinstance(get_script_binary_path(), Path)

        # Test optional return types
        venv_path = get_venv_path()
        if venv_path is not None:
            assert isinstance(venv_path, Path)