Merge branch 'main' into add-ws-transport

This commit is contained in:
yashksaini-coder
2025-09-04 15:19:11 +05:30
committed by GitHub
61 changed files with 3936 additions and 634 deletions

View File

@ -1,3 +1,10 @@
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
from libp2p import (
new_swarm,
)
@ -10,6 +17,9 @@ from libp2p.host.basic_host import (
from libp2p.host.defaults import (
get_default_protocols,
)
from libp2p.host.exceptions import (
StreamFailure,
)
def test_default_protocols():
@ -22,3 +32,30 @@ def test_default_protocols():
# NOTE: comparing keys for equality as handlers may be closures that do not compare
# in the way this test is concerned with
assert handlers.keys() == get_default_protocols(host).keys()
@pytest.mark.trio
async def test_swarm_stream_handler_no_protocol_selected(monkeypatch):
key_pair = create_new_key_pair()
swarm = new_swarm(key_pair)
host = BasicHost(swarm)
# Create a mock net_stream
net_stream = MagicMock()
net_stream.reset = AsyncMock()
net_stream.muxed_conn.peer_id = "peer-test"
# Monkeypatch negotiate to simulate "no protocol selected"
async def fake_negotiate(comm, timeout):
return None, None
monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate)
# Now run the handler and expect StreamFailure
with pytest.raises(
StreamFailure, match="Failed to negotiate protocol: no protocol selected"
):
await host._swarm_stream_handler(net_stream)
# Ensure reset was called since negotiation failed
net_stream.reset.assert_awaited()

View File

@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
assert peer_a_id in host_b.get_live_peers()
# Simulate unexpected connection drop by directly closing the connection
conn = host_a.get_network().connections[peer_b_id]
await conn.muxed_conn.close()
conns = host_a.get_network().connections[peer_b_id]
await conns[0].muxed_conn.close()
# Allow for connection cleanup
await trio.sleep(0.1)

View File

@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including:
import hashlib
import logging
import os
from unittest.mock import patch
import uuid
import pytest
import multiaddr
import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import (
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.tools.async_service import (
background_trio_service,
)
@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before the next FIND_NODE
# req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verifies if the senderRecord in the FIND_NODE request is correctly processed
assert isinstance(
dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope
)
# Verifies if the senderRecord in the FIND_NODE response is correctly processed
assert isinstance(
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
)
# These are the records that were sent between the peers during the FIND_NODE req
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_peer, Envelope)
assert isinstance(envelope_b_find_peer, Envelope)
record_a_find_peer = envelope_a_find_peer.record()
record_b_find_peer = envelope_b_find_peer.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during FIND_NODE execution, which proves the
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
assert record_a.seq == record_a_find_peer.seq
assert record_b.seq == record_b_find_peer.seq
# Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_put_value, Envelope)
assert isinstance(envelope_b_put_value, Envelope)
record_a_put_value = envelope_a_put_value.record()
record_b_put_value = envelope_b_put_value.record()
# This proves that both the records are same, and a latest cached signed record
# was passed between the peers during PUT_VALUE execution, which proves the
# signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
assert record_a.seq == record_a_put_value.seq
assert record_b.seq == record_b_put_value.seq
# # Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10])
logger.debug("Node A value store: %s", dht_a.value_store.store)
print("hello test")
# # Allow more time for the value to propagate
await trio.sleep(0.5)
@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value)
# These are the records that were sent between the peers during the PUT_VALUE req
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that there was no record exchange between the nodes during GET_VALUE
# execution, as dht_b already had the key/value pair stored locally after the
# PUT_VALUE execution.
assert record_a_get_value.seq == record_a_put_value.seq
assert record_b_get_value.seq == record_b_put_value.seq
# Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value"
@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
# Store content on the first node
dht_a.value_store.put(content_id, content)
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
# so both the nodes will have records of each other before PUT_VALUE req is sent
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
assert isinstance(envelope_a, Envelope)
assert isinstance(envelope_b, Envelope)
record_a = envelope_a.record()
record_b = envelope_b.record()
# Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider"
# These are the records that were sent between the peers during
# the ADD_PROVIDER req
envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_add_prov, Envelope)
assert isinstance(envelope_b_add_prov, Envelope)
record_a_add_prov = envelope_a_add_prov.record()
record_b_add_prov = envelope_b_add_prov.record()
# This proves that both the records are same, the latest cached signed record
# was passed between the peers during ADD_PROVIDER execution, which proves the
# signed-record transfer/re-issuing of the latest record works correctly in
# ADD_PROVIDER executions.
assert record_a.seq == record_a_add_prov.seq
assert record_b.seq == record_b_add_prov.seq
# Allow time for the provider record to propagate
await trio.sleep(0.1)
@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id)
# These are the records in each peer after the find_provider execution
envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_find_prov, Envelope)
assert isinstance(envelope_b_find_prov, Envelope)
record_a_find_prov = envelope_a_find_prov.record()
record_b_find_prov = envelope_b_find_prov.record()
# This proves that both the records are same, as the dht_b already
# has the provider record for the content_id, after the ADD_PROVIDER
# advertisement by dht_a
assert record_a_find_prov.seq == record_a_add_prov.seq
assert record_b_find_prov.seq == record_b_add_prov.seq
# Verify that we found the first node as a provider
assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
@ -166,3 +319,143 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
assert retrieved_value == content, (
"Retrieved content does not match the original"
)
# These are the record state of each peer aftet the GET_VALUE execution
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
dht_b.host.get_id()
)
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
dht_a.host.get_id()
)
assert isinstance(envelope_a_get_value, Envelope)
assert isinstance(envelope_b_get_value, Envelope)
record_a_get_value = envelope_a_get_value.record()
record_b_get_value = envelope_b_get_value.record()
# This proves that both the records are same, meaning that the latest cached
# signed-record tranfer happened during the GET_VALUE execution by dht_b,
# which means the signed-record transfer/re-issuing works correctly
# in GET_VALUE executions.
assert record_a_find_prov.seq == record_a_get_value.seq
assert record_b_find_prov.seq == record_b_get_value.seq
# Create a new provider record in dht_a
provider_key_pair = create_new_key_pair()
provider_peer_id = ID.from_pubkey(provider_key_pair.public_key)
provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr])
# Generate a random content ID
content_2 = f"random-content-{uuid.uuid4()}".encode()
content_id_2 = hashlib.sha256(content_2).digest()
provider_signed_envelope = create_signed_peer_record(
provider_peer_id, [provider_addr], provider_key_pair.private_key
)
assert (
dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200)
is True
)
# Store this provider record in dht_a
dht_a.provider_store.add_provider(content_id_2, provider_peer_info)
# Fetch the provider-record via peer-discovery at dht_b's end
peerinfo = await dht_b.provider_store.find_providers(content_id_2)
assert len(peerinfo) == 1
assert peerinfo[0].peer_id == provider_peer_id
provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id)
# This proves that the signed-envelope of provider is consumed on dht_b's end
assert provider_envelope is not None
assert (
provider_signed_envelope.marshal_envelope()
== provider_envelope.marshal_envelope()
)
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
dht_a, dht_b = dht_pair
# Warm-up: A stores B's current record
with trio.fail_after(10):
await dht_a.find_peer(dht_b.host.get_id())
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env0, Envelope)
seq0 = env0.record().seq
# Simulate B's listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force B to respond:
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
# Force B to send a response (which should include a fresh SPR)
with trio.fail_after(10):
await dht_a.peer_routing._query_peer_for_closest(
dht_b.host.get_id(), os.urandom(32)
)
# A should now hold B's new record with a bumped seq
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
assert isinstance(env1, Envelope)
seq1 = env1.record().seq
# This proves that upon the change in listen_addrs, we issue new records
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"
@pytest.mark.trio
async def test_dht_req_fail_with_invalid_record_transfer(
dht_pair: tuple[KadDHT, KadDHT],
):
"""
Testing showing failure of storing and retrieving values in the DHT,
if invalid signed-records are sent.
"""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
await dht_a.routing_table.add_peer(peer_b_info)
# Corrupt dht_a's local peer_record
envelope = dht_a.host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
dht_a.host.get_peerstore().set_local_record(envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
# the corrupted invalid record
assert retrieved_value is None
# Create a corrupt envelope with correct signature but false peer_id
false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs)
false_envelope = seal_record(false_record, dht_a.host.get_private_key())
dht_a.host.get_peerstore().set_local_record(false_envelope)
await dht_a.put_value(key, value)
retrieved_value = dht_b.value_store.get(key)
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
# the record with a different peer_id regardless of a valid signature
assert retrieved_value is None

View File

@ -57,7 +57,10 @@ class TestPeerRouting:
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_id.return_value = create_valid_peer_id("local")
key_pair = create_new_key_pair()
host.get_id.return_value = ID.from_pubkey(key_pair.public_key)
host.get_public_key.return_value = key_pair.public_key
host.get_private_key.return_value = key_pair.private_key
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()

View File

@ -0,0 +1,325 @@
import time
from typing import cast
from unittest.mock import Mock
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import INetConn, INetStream
from libp2p.network.exceptions import SwarmException
from libp2p.network.swarm import (
ConnectionConfig,
RetryConfig,
Swarm,
)
from libp2p.peer.id import ID
class MockConnection(INetConn):
"""Mock connection for testing."""
def __init__(self, peer_id: ID, is_closed: bool = False):
self.peer_id = peer_id
self._is_closed = is_closed
self.streams = set() # Track streams properly
# Mock the muxed_conn attribute that Swarm expects
self.muxed_conn = Mock()
self.muxed_conn.peer_id = peer_id
# Required by INetConn interface
self.event_started = trio.Event()
async def close(self):
self._is_closed = True
@property
def is_closed(self) -> bool:
return self._is_closed
async def new_stream(self) -> INetStream:
# Create a mock stream and add it to the connection's stream set
mock_stream = Mock(spec=INetStream)
self.streams.add(mock_stream)
return mock_stream
def get_streams(self) -> tuple[INetStream, ...]:
"""Return all streams associated with this connection."""
return tuple(self.streams)
def get_transport_addresses(self) -> list[Multiaddr]:
"""Mock implementation of get_transport_addresses."""
return []
class MockNetStream(INetStream):
"""Mock network stream for testing."""
def __init__(self, peer_id: ID):
self.peer_id = peer_id
@pytest.mark.trio
async def test_retry_config_defaults():
"""Test RetryConfig default values."""
config = RetryConfig()
assert config.max_retries == 3
assert config.initial_delay == 0.1
assert config.max_delay == 30.0
assert config.backoff_multiplier == 2.0
assert config.jitter_factor == 0.1
@pytest.mark.trio
async def test_connection_config_defaults():
"""Test ConnectionConfig default values."""
config = ConnectionConfig()
assert config.max_connections_per_peer == 3
assert config.connection_timeout == 30.0
assert config.load_balancing_strategy == "round_robin"
@pytest.mark.trio
async def test_enhanced_swarm_constructor():
"""Test enhanced Swarm constructor with new configuration."""
# Create mock dependencies
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Test with default config
swarm = Swarm(peer_id, peerstore, upgrader, transport)
assert swarm.retry_config.max_retries == 3
assert swarm.connection_config.max_connections_per_peer == 3
assert isinstance(swarm.connections, dict)
# Test with custom config
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
custom_conn = ConnectionConfig(max_connections_per_peer=5)
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
assert swarm.retry_config.max_retries == 5
assert swarm.retry_config.initial_delay == 0.5
assert swarm.connection_config.max_connections_per_peer == 5
@pytest.mark.trio
async def test_swarm_backoff_calculation():
"""Test exponential backoff calculation with jitter."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
retry_config = RetryConfig(
initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1
)
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
# Test backoff calculation
delay1 = swarm._calculate_backoff_delay(0)
delay2 = swarm._calculate_backoff_delay(1)
delay3 = swarm._calculate_backoff_delay(2)
# Should increase exponentially
assert delay2 > delay1
assert delay3 > delay2
# Should respect max delay
assert delay1 <= 1.0
assert delay2 <= 1.0
assert delay3 <= 1.0
# Should have jitter
assert delay1 != 0.1 # Should have jitter added
@pytest.mark.trio
async def test_swarm_retry_logic():
"""Test retry logic in dial operations."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Configure for fast testing
retry_config = RetryConfig(
max_retries=2,
initial_delay=0.01, # Very short for testing
max_delay=0.1,
)
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
# Mock the single attempt method to fail twice then succeed
attempt_count = [0]
async def mock_single_attempt(addr, peer_id):
attempt_count[0] += 1
if attempt_count[0] < 3:
raise SwarmException(f"Attempt {attempt_count[0]} failed")
return MockConnection(peer_id)
swarm._dial_addr_single_attempt = mock_single_attempt
# Test retry logic
start_time = time.time()
result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id)
end_time = time.time()
# Should have succeeded after 3 attempts
assert attempt_count[0] == 3
assert isinstance(result, MockConnection)
assert end_time - start_time > 0.01 # Should have some delay
@pytest.mark.trio
async def test_swarm_load_balancing_strategies():
"""Test load balancing strategies."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Create mock connections with different stream counts
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
# Add some streams to simulate load
await conn1.new_stream()
await conn1.new_stream()
await conn2.new_stream()
connections = [conn1, conn2, conn3]
# Test round-robin strategy
swarm.connection_config.load_balancing_strategy = "round_robin"
# Cast to satisfy type checker
connections_cast = cast("list[INetConn]", connections)
selected1 = swarm._select_connection(connections_cast, peer_id)
selected2 = swarm._select_connection(connections_cast, peer_id)
selected3 = swarm._select_connection(connections_cast, peer_id)
# Should cycle through connections
assert selected1 in connections
assert selected2 in connections
assert selected3 in connections
# Test least loaded strategy
swarm.connection_config.load_balancing_strategy = "least_loaded"
least_loaded = swarm._select_connection(connections_cast, peer_id)
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
# So conn3 should be selected as least loaded
assert least_loaded == conn3
# Test default strategy (first connection)
swarm.connection_config.load_balancing_strategy = "unknown"
default_selected = swarm._select_connection(connections_cast, peer_id)
assert default_selected == conn1
@pytest.mark.trio
async def test_swarm_multiple_connections_api():
"""Test the new multiple connections API methods."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Test empty connections
assert swarm.get_connections() == []
assert swarm.get_connections(peer_id) == []
assert swarm.get_connection(peer_id) is None
assert swarm.get_connections_map() == {}
# Add some connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test get_connections with peer_id
peer_connections = swarm.get_connections(peer_id)
assert len(peer_connections) == 2
assert conn1 in peer_connections
assert conn2 in peer_connections
# Test get_connections without peer_id (all connections)
all_connections = swarm.get_connections()
assert len(all_connections) == 2
assert conn1 in all_connections
assert conn2 in all_connections
# Test get_connection (backward compatibility)
single_conn = swarm.get_connection(peer_id)
assert single_conn in [conn1, conn2]
# Test get_connections_map
connections_map = swarm.get_connections_map()
assert peer_id in connections_map
assert connections_map[peer_id] == [conn1, conn2]
@pytest.mark.trio
async def test_swarm_connection_trimming():
"""Test connection trimming when limit is exceeded."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
# Set max connections to 2
connection_config = ConnectionConfig(max_connections_per_peer=2)
swarm = Swarm(
peer_id, peerstore, upgrader, transport, connection_config=connection_config
)
# Add 3 connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
conn3 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2, conn3]
# Trigger trimming
swarm._trim_connections(peer_id)
# Should have only 2 connections
assert len(swarm.connections[peer_id]) == 2
# The most recent connections should remain
remaining = swarm.connections[peer_id]
assert conn2 in remaining
assert conn3 in remaining
@pytest.mark.trio
async def test_swarm_backward_compatibility():
"""Test backward compatibility features."""
peer_id = ID(b"QmTest")
peerstore = Mock()
upgrader = Mock()
transport = Mock()
swarm = Swarm(peer_id, peerstore, upgrader, transport)
# Add connections
conn1 = MockConnection(peer_id)
conn2 = MockConnection(peer_id)
swarm.connections[peer_id] = [conn1, conn2]
# Test connections_legacy property
legacy_connections = swarm.connections_legacy
assert peer_id in legacy_connections
# Should return first connection
assert legacy_connections[peer_id] in [conn1, conn2]
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
from libp2p.network.swarm import (
Swarm,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.utils import (
connect_swarm,
)
@ -48,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# New: dial_peer now returns list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Verify connections are established in both directions
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert new_connections == existing_connections
@pytest.mark.trio
@ -104,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
# Get the first connection from the list
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
@ -112,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.trio
async def test_swarm_multiple_connections(security_protocol):
"""Test multiple connections per peer functionality."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup multiple addresses for peer
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Dial peer - should return list of connections
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert len(connections) > 0
# Test get_connections method
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
assert len(peer_connections) == len(connections)
# Test get_connections_map method
connections_map = swarms[0].get_connections_map()
assert swarms[1].get_peer_id() in connections_map
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
# Test get_connection method (backward compatibility)
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
assert single_conn is not None
assert single_conn in connections
@pytest.mark.trio
async def test_swarm_load_balancing(security_protocol):
"""Test load balancing across multiple connections."""
async with SwarmFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as swarms:
# Setup connection
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
# Create multiple streams - should use load balancing
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Verify streams were created successfully
assert len(streams) == 5
# Clean up
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_swarm_multiaddr(security_protocol):
async with SwarmFactory.create_batch_and_listen(
@ -184,3 +254,116 @@ def test_new_swarm_quic_multiaddr_raises():
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses(security_protocol):
"""Test that swarm can listen on multiple addresses simultaneously."""
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm):
# Listen on all addresses
success = await swarm.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Check that we have listeners for the addresses
actual_listeners = list(swarm.listeners.keys())
assert len(actual_listeners) > 0, "Should have at least one listener"
# Verify that all successful listeners are in the listeners dict
successful_count = 0
for addr in listen_addrs:
addr_str = str(addr)
if addr_str in actual_listeners:
successful_count += 1
# This address successfully started listening
listener = swarm.listeners[addr_str]
listener_addrs = listener.get_addrs()
assert len(listener_addrs) > 0, (
f"Listener for {addr} should have addresses"
)
# Check that the listener address matches the expected address
# (port might be different if we used port 0)
expected_ip = addr.value_for_protocol("ip4")
expected_protocol = addr.value_for_protocol("tcp")
if expected_ip and expected_protocol:
found_matching = False
for listener_addr in listener_addrs:
if (
listener_addr.value_for_protocol("ip4") == expected_ip
and listener_addr.value_for_protocol("tcp") is not None
):
found_matching = True
break
assert found_matching, (
f"Listener for {addr} should have matching IP"
)
assert successful_count == len(listen_addrs), (
f"All {len(listen_addrs)} addresses should be listening, "
f"but only {successful_count} succeeded"
)
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm1):
# Listen on all addresses
success = await swarm1.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Verify all available interfaces are listening
assert len(swarm1.listeners) == len(listen_addrs), (
f"All {len(listen_addrs)} interfaces should be listening, "
f"but only {len(swarm1.listeners)} are"
)
# Create a second swarm to test connections
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm2):
# Test connectivity to each listening address using real libp2p connections
for addr_str, listener in swarm1.listeners.items():
listener_addrs = listener.get_addrs()
for listener_addr in listener_addrs:
# Create a full multiaddr with peer ID for libp2p connection
peer_id = swarm1.get_peer_id()
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
# Test real libp2p connection
try:
peer_info = info_from_p2p_addr(full_addr)
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
swarm2.peerstore.add_addrs(
peer_info.peer_id, [listener_addr], 10000
)
await swarm2.dial_peer(peer_info.peer_id)
# Verify connection was established
assert peer_info.peer_id in swarm2.connections, (
f"Connection to {full_addr} should be established"
)
assert swarm2.get_peer_id() in swarm1.connections, (
f"Connection from {full_addr} should be established"
)
except Exception as e:
pytest.fail(
f"Failed to establish libp2p connection to {full_addr}: {e}"
)

View File

@ -8,8 +8,10 @@ from typing import (
from unittest.mock import patch
import pytest
import multiaddr
import trio
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
@ -17,9 +19,11 @@ from libp2p.exceptions import (
from libp2p.network.stream.exceptions import (
StreamEOF,
)
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peer_record import PeerRecord
from libp2p.pubsub.pb import (
rpc_pb2,
)
@ -87,6 +91,45 @@ async def test_re_unsubscribe():
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.trio
async def test_reissue_when_listen_addrs_change():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
# Check whether signed-records were transfered properly in the subscribe call
envelope_b_sub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_sub, Envelope)
# Simulate pubsubs_fsub[1].host listen addrs changing (different port)
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
# Patch just for the duration we force A to unsubscribe
with patch.object(pubsubs_fsub[0].host, "get_addrs", return_value=[new_addr]):
# Unsubscribe from A's side so that a new_record is issued
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
await trio.sleep(1)
# B should be holding A's new record with bumped seq
envelope_b_unsub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_unsub, Envelope)
# This proves that a freshly signed record was issued rather than
# the latest-cached-one creating one.
assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq
@pytest.mark.trio
async def test_peers_subscribe():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
@ -95,11 +138,71 @@ async def test_peers_subscribe():
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
# Check whether signed-records were transfered properly in the subscribe call
envelope_b_sub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_sub, Envelope)
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
envelope_b_unsub = (
pubsubs_fsub[1]
.host.get_peerstore()
.get_peer_record(pubsubs_fsub[0].host.get_id())
)
assert isinstance(envelope_b_unsub, Envelope)
# This proves that the latest-cached-record was re-issued rather than
# freshly creating one.
assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq
@pytest.mark.trio
async def test_peer_subscribe_fail_upon_invald_record_transfer():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
# Corrupt host_a's local peer record
envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record()
if envelope is not None:
true_record = envelope.record()
key_pair = create_new_key_pair()
if envelope is not None:
envelope.public_key = key_pair.public_key
pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yeild to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
TESTING_TOPIC, set()
)
# Create a corrupt envelope with correct signature but false peer-id
false_record = PeerRecord(
ID.from_pubkey(key_pair.public_key), true_record.addrs
)
false_envelope = seal_record(
false_record, pubsubs_fsub[0].host.get_private_key()
)
pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yeild to let 0 notify 1
await trio.sleep(1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
TESTING_TOPIC, set()
)
@pytest.mark.trio
async def test_get_hello_packet():

View File

@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol):
# Extract the secured connection from either Mplex or Yamux implementation
def get_secured_conn(conn):
# conn is now a list, get the first connection
if isinstance(conn, list):
conn = conn[0]
muxed_conn = conn.muxed_conn
# Direct attribute access for known implementations
has_secured_conn = hasattr(muxed_conn, "secured_conn")

View File

@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference):
assert len(connections) > 0, "Connection not established"
# Get the first connection
conn = list(connections.values())[0]
conns = list(connections.values())[0]
conn = conns[0] # Get first connection from the list
muxed_conn = conn.muxed_conn
# Define a simple echo protocol
@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class):
assert len(connections) > 0, "Connection not established"
# Get the first connection
conn = list(connections.values())[0]
conns = list(connections.values())[0]
conn = conns[0] # Get first connection from the list
muxed_conn = conn.muxed_conn
# Define a simple echo protocol
@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default):
assert len(connections) > 0, "Connection not established"
# Get the first connection
conn = list(connections.values())[0]
conns = list(connections.values())[0]
conn = conns[0] # Get first connection from the list
muxed_conn = conn.muxed_conn
# Define a simple echo protocol

View File

@ -0,0 +1,109 @@
import contextlib
import os
from pathlib import Path
import subprocess
import sys
import time
from multiaddr import Multiaddr
from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP
# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging
# This test is intentionally lightweight and can be marked as 'integration'.
# It ensures the echo example runs and prints the new Thin Waist lines using
# Trio primitives.
current_file = Path(__file__)
project_root = current_file.parent.parent.parent
EXAMPLES_DIR: Path = project_root / "examples" / "echo"
def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path):
"""Run echo server and validate printed multiaddr and peer id."""
# Run echo example as server
cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"]
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
proc: subprocess.Popen[str] = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
env=env,
)
if proc.stdout is None:
proc.terminate()
raise RuntimeError("Process stdout is None")
out_stream = proc.stdout
peer_id: str | None = None
printed_multiaddr: str | None = None
saw_waiting = False
start = time.time()
timeout_s = 8.0
try:
while time.time() - start < timeout_s:
line = out_stream.readline()
if not line:
time.sleep(0.05)
continue
s = line.strip()
if s.startswith("I am "):
peer_id = s.partition("I am ")[2]
if s.startswith("echo-demo -d "):
printed_multiaddr = s.partition("echo-demo -d ")[2]
if "Waiting for incoming connections..." in s:
saw_waiting = True
break
finally:
with contextlib.suppress(ProcessLookupError):
proc.terminate()
with contextlib.suppress(ProcessLookupError):
proc.kill()
assert peer_id, "Did not capture peer ID line"
assert printed_multiaddr, "Did not capture multiaddr line"
assert saw_waiting, "Did not capture waiting-for-connections line"
# Validate multiaddr structure using py-multiaddr protocol methods
ma = Multiaddr(printed_multiaddr) # should parse without error
# Check that the multiaddr contains the p2p protocol
try:
peer_id_from_multiaddr = ma.value_for_protocol("p2p")
assert peer_id_from_multiaddr is not None, (
"Multiaddr missing p2p protocol value"
)
assert peer_id_from_multiaddr == peer_id, (
f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}"
)
except Exception as e:
raise AssertionError(f"Failed to extract p2p protocol value: {e}")
# Validate the multiaddr structure by checking protocols
protocols = ma.protocols()
# Should have at least IP, TCP, and P2P protocols
assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), (
"Missing IP protocol"
)
assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol"
assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol"
# Extract the p2p part and validate it matches the captured peer ID
p2p_part = Multiaddr(f"/p2p/{peer_id}")
try:
# Decapsulate the p2p part to get the transport address
transport_addr = ma.decapsulate(p2p_part)
# Verify the decapsulated address doesn't contain p2p
transport_protocols = transport_addr.protocols()
assert not any(p.code == P_P2P for p in transport_protocols), (
"Decapsulation failed - still contains p2p"
)
# Verify the original multiaddr can be reconstructed
reconstructed = transport_addr.encapsulate(p2p_part)
assert str(reconstructed) == str(ma), "Reconstruction failed"
except Exception as e:
raise AssertionError(f"Multiaddr decapsulation failed: {e}")

View File

@ -669,8 +669,8 @@ async def swarm_conn_pair_factory(
async with swarm_pair_factory(
security_protocol=security_protocol, muxer_opt=muxer_opt
) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0]
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)

View File

@ -0,0 +1,56 @@
import os
import pytest
from multiaddr import Multiaddr
from libp2p.utils.address_validation import (
expand_wildcard_address,
get_available_interfaces,
get_optimal_binding_address,
)
@pytest.mark.parametrize("proto", ["tcp"])
def test_get_available_interfaces(proto: str) -> None:
interfaces = get_available_interfaces(0, protocol=proto)
assert len(interfaces) > 0
for addr in interfaces:
assert isinstance(addr, Multiaddr)
assert f"/{proto}/" in str(addr)
def test_get_optimal_binding_address() -> None:
addr = get_optimal_binding_address(0)
assert isinstance(addr, Multiaddr)
# At least IPv4 or IPv6 prefix present
s = str(addr)
assert ("/ip4/" in s) or ("/ip6/" in s)
def test_expand_wildcard_address_ipv4() -> None:
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0")
expanded = expand_wildcard_address(wildcard)
assert len(expanded) > 0
for e in expanded:
assert isinstance(e, Multiaddr)
assert "/tcp/" in str(e)
def test_expand_wildcard_address_port_override() -> None:
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000")
overridden = expand_wildcard_address(wildcard, port=9001)
assert len(overridden) > 0
for e in overridden:
assert str(e).endswith("/tcp/9001")
@pytest.mark.skipif(
os.environ.get("NO_IPV6") == "1",
reason="Environment disallows IPv6",
)
def test_expand_wildcard_address_ipv6() -> None:
wildcard = Multiaddr("/ip6/::/tcp/0")
expanded = expand_wildcard_address(wildcard)
assert len(expanded) > 0
for e in expanded:
assert "/ip6/" in str(e)

View File

@ -15,6 +15,7 @@ import pytest
import trio
from libp2p.utils.logging import (
_current_handlers,
_current_listener,
_listener_ready,
log_queue,
@ -24,13 +25,19 @@ from libp2p.utils.logging import (
def _reset_logging():
"""Reset all logging state."""
global _current_listener, _listener_ready
global _current_listener, _listener_ready, _current_handlers
# Stop existing listener if any
if _current_listener is not None:
_current_listener.stop()
_current_listener = None
# Close all file handlers to ensure proper cleanup on Windows
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
_current_handlers.clear()
# Reset the event
_listener_ready = threading.Event()
@ -174,6 +181,15 @@ async def test_custom_log_file(clean_env):
if _current_listener is not None:
_current_listener.stop()
# Give a moment for the listener to fully stop
await trio.sleep(0.05)
# Close all file handlers to release the file
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.flush() # Ensure all writes are flushed
handler.close()
# Check if the file exists and contains our message
assert log_file.exists()
content = log_file.read_text()
@ -185,16 +201,15 @@ async def test_default_log_file(clean_env):
"""Test logging to the default file path."""
os.environ["LIBP2P_DEBUG"] = "INFO"
with patch("libp2p.utils.logging.datetime") as mock_datetime:
# Mock the timestamp to have a predictable filename
mock_datetime.now.return_value.strftime.return_value = "20240101_120000"
with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp:
# Mock the temp file creation to return a predictable path
mock_temp_file = (
Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log"
)
mock_create_temp.return_value = mock_temp_file
# Remove the log file if it exists
if os.name == "nt": # Windows
log_file = Path("C:/Windows/Temp/20240101_120000_py-libp2p.log")
else: # Unix-like
log_file = Path("/tmp/20240101_120000_py-libp2p.log")
log_file.unlink(missing_ok=True)
mock_temp_file.unlink(missing_ok=True)
setup_logging()
@ -211,9 +226,18 @@ async def test_default_log_file(clean_env):
if _current_listener is not None:
_current_listener.stop()
# Check the default log file
if log_file.exists(): # Only check content if we have write permission
content = log_file.read_text()
# Give a moment for the listener to fully stop
await trio.sleep(0.05)
# Close all file handlers to release the file
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.flush() # Ensure all writes are flushed
handler.close()
# Check the mocked temp file
if mock_temp_file.exists():
content = mock_temp_file.read_text()
assert "Test message" in content

290
tests/utils/test_paths.py Normal file
View File

@ -0,0 +1,290 @@
"""
Tests for cross-platform path utilities.
"""
import os
from pathlib import Path
import tempfile
import pytest
from libp2p.utils.paths import (
create_temp_file,
ensure_dir_exists,
find_executable,
get_binary_path,
get_config_dir,
get_project_root,
get_python_executable,
get_script_binary_path,
get_script_dir,
get_temp_dir,
get_venv_path,
join_paths,
normalize_path,
resolve_relative_path,
)
class TestPathUtilities:
"""Test cross-platform path utilities."""
def test_get_temp_dir(self):
"""Test that temp directory is accessible and exists."""
temp_dir = get_temp_dir()
assert isinstance(temp_dir, Path)
assert temp_dir.exists()
assert temp_dir.is_dir()
# Should match system temp directory
assert temp_dir == Path(tempfile.gettempdir())
def test_get_project_root(self):
"""Test that project root is correctly determined."""
project_root = get_project_root()
assert isinstance(project_root, Path)
assert project_root.exists()
# Should contain pyproject.toml
assert (project_root / "pyproject.toml").exists()
# Should contain libp2p directory
assert (project_root / "libp2p").exists()
def test_join_paths(self):
"""Test cross-platform path joining."""
# Test with strings
result = join_paths("a", "b", "c")
expected = Path("a") / "b" / "c"
assert result == expected
# Test with mixed types
result = join_paths("a", Path("b"), "c")
expected = Path("a") / "b" / "c"
assert result == expected
# Test with absolute path
result = join_paths("/absolute", "path")
expected = Path("/absolute") / "path"
assert result == expected
def test_ensure_dir_exists(self, tmp_path):
"""Test directory creation and existence checking."""
# Test creating new directory
new_dir = tmp_path / "new_dir"
result = ensure_dir_exists(new_dir)
assert result == new_dir
assert new_dir.exists()
assert new_dir.is_dir()
# Test creating nested directory
nested_dir = tmp_path / "parent" / "child" / "grandchild"
result = ensure_dir_exists(nested_dir)
assert result == nested_dir
assert nested_dir.exists()
assert nested_dir.is_dir()
# Test with existing directory
result = ensure_dir_exists(new_dir)
assert result == new_dir
assert new_dir.exists()
def test_get_config_dir(self):
"""Test platform-specific config directory."""
config_dir = get_config_dir()
assert isinstance(config_dir, Path)
if os.name == "nt": # Windows
# Should be in AppData/Roaming or user home
assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir)
else: # Unix-like
# Should be in ~/.config
assert ".config" in str(config_dir)
assert "py-libp2p" in str(config_dir)
def test_get_script_dir(self):
"""Test script directory detection."""
# Test with current file
script_dir = get_script_dir(__file__)
assert isinstance(script_dir, Path)
assert script_dir.exists()
assert script_dir.is_dir()
# Should contain this test file
assert (script_dir / "test_paths.py").exists()
def test_create_temp_file(self):
"""Test temporary file creation."""
temp_file = create_temp_file()
assert isinstance(temp_file, Path)
assert temp_file.parent == get_temp_dir()
assert temp_file.name.startswith("py-libp2p_")
assert temp_file.name.endswith(".log")
# Test with custom prefix and suffix
temp_file = create_temp_file(prefix="test_", suffix=".txt")
assert temp_file.name.startswith("test_")
assert temp_file.name.endswith(".txt")
def test_resolve_relative_path(self, tmp_path):
"""Test relative path resolution."""
base_path = tmp_path / "base"
base_path.mkdir()
# Test relative path
relative_path = "subdir/file.txt"
result = resolve_relative_path(base_path, relative_path)
expected = (base_path / "subdir" / "file.txt").resolve()
assert result == expected
# Test absolute path (platform-agnostic)
if os.name == "nt": # Windows
absolute_path = "C:\\absolute\\path"
else: # Unix-like
absolute_path = "/absolute/path"
result = resolve_relative_path(base_path, absolute_path)
assert result == Path(absolute_path)
def test_normalize_path(self, tmp_path):
"""Test path normalization."""
# Test with relative path
relative_path = tmp_path / ".." / "normalize_test"
result = normalize_path(relative_path)
assert result.is_absolute()
assert "normalize_test" in str(result)
# Test with absolute path
absolute_path = tmp_path / "test_file"
result = normalize_path(absolute_path)
assert result.is_absolute()
assert result == absolute_path.resolve()
def test_get_venv_path(self, monkeypatch):
"""Test virtual environment path detection."""
# Test when no virtual environment is active
# Temporarily clear VIRTUAL_ENV to test the "no venv" case
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
result = get_venv_path()
assert result is None
# Test when virtual environment is active
test_venv_path = "/path/to/venv"
monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
result = get_venv_path()
assert result == Path(test_venv_path)
def test_get_python_executable(self):
"""Test Python executable path detection."""
result = get_python_executable()
assert isinstance(result, Path)
assert result.exists()
assert result.name.startswith("python")
def test_find_executable(self):
"""Test executable finding in PATH."""
# Test with non-existent executable
result = find_executable("nonexistent_executable")
assert result is None
# Test with existing executable (python should be available)
result = find_executable("python")
if result:
assert isinstance(result, Path)
assert result.exists()
def test_get_script_binary_path(self):
"""Test script binary path detection."""
result = get_script_binary_path()
assert isinstance(result, Path)
assert result.exists()
assert result.is_dir()
def test_get_binary_path(self, monkeypatch):
"""Test binary path resolution with virtual environment."""
# Test when no virtual environment is active
result = get_binary_path("python")
if result:
assert isinstance(result, Path)
assert result.exists()
# Test when virtual environment is active
test_venv_path = "/path/to/venv"
monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
# This test is more complex as it depends on the actual venv structure
# We'll just verify the function doesn't crash
result = get_binary_path("python")
# Result can be None if binary not found in venv
if result:
assert isinstance(result, Path)
class TestCrossPlatformCompatibility:
"""Test cross-platform compatibility."""
def test_config_dir_platform_specific_windows(self, monkeypatch):
"""Test config directory respects Windows conventions."""
import platform
# Only run this test on Windows systems
if platform.system() != "Windows":
pytest.skip("This test only runs on Windows systems")
monkeypatch.setattr("os.name", "nt")
monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
config_dir = get_config_dir()
assert "AppData" in str(config_dir)
assert "py-libp2p" in str(config_dir)
def test_path_separators_consistent(self):
"""Test that path separators are handled consistently."""
# Test that join_paths uses platform-appropriate separators
result = join_paths("dir1", "dir2", "file.txt")
expected = Path("dir1") / "dir2" / "file.txt"
assert result == expected
# Test that the result uses correct separators for the platform
if os.name == "nt": # Windows
assert "\\" in str(result) or "/" in str(result)
else: # Unix-like
assert "/" in str(result)
def test_temp_file_uniqueness(self):
"""Test that temporary files have unique names."""
files = set()
for _ in range(10):
temp_file = create_temp_file()
assert temp_file not in files
files.add(temp_file)
class TestBackwardCompatibility:
"""Test backward compatibility with existing code patterns."""
def test_path_operations_equivalent(self):
"""Test that new path operations are equivalent to old os.path operations."""
# Test join_paths vs os.path.join
parts = ["a", "b", "c"]
new_result = join_paths(*parts)
old_result = Path(os.path.join(*parts))
assert new_result == old_result
# Test get_script_dir vs os.path.dirname(os.path.abspath(__file__))
new_script_dir = get_script_dir(__file__)
old_script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
assert new_script_dir == old_script_dir
def test_existing_functionality_preserved(self):
"""Ensure no existing functionality is broken."""
# Test that all functions return Path objects
assert isinstance(get_temp_dir(), Path)
assert isinstance(get_project_root(), Path)
assert isinstance(join_paths("a", "b"), Path)
assert isinstance(ensure_dir_exists(tempfile.gettempdir()), Path)
assert isinstance(get_config_dir(), Path)
assert isinstance(get_script_dir(__file__), Path)
assert isinstance(create_temp_file(), Path)
assert isinstance(resolve_relative_path(".", "test"), Path)
assert isinstance(normalize_path("."), Path)
assert isinstance(get_python_executable(), Path)
assert isinstance(get_script_binary_path(), Path)
# Test optional return types
venv_path = get_venv_path()
if venv_path is not None:
assert isinstance(venv_path, Path)