Merge branch 'main' into add-read-write-lock

This commit is contained in:
Manu Sheel Gupta
2025-07-16 14:59:44 -07:00
committed by GitHub
23 changed files with 2290 additions and 62 deletions

View File

@ -11,9 +11,7 @@ from libp2p.identity.identify.identify import (
PROTOCOL_VERSION,
_mk_identify_protobuf,
_multiaddr_to_bytes,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
parse_identify_response,
)
from tests.utils.factories import (
host_pair_factory,
@ -29,14 +27,18 @@ async def test_identify_protocol(security_protocol):
host_b,
):
# Here, host_b is the requester and host_a is the responder.
# observed_addr represent host_bs address as observed by host_a
# (i.e., the address from which host_bs request was received).
# observed_addr represents host_b's address as observed by host_a
# (i.e., the address from which host_b's request was received).
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read()
# Read the response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
await stream.close()
identify_response = Identify()
identify_response.ParseFromString(response)
# Parse the response (handles both old and new formats)
identify_response = parse_identify_response(response)
logger.debug("host_a: %s", host_a.get_addrs())
logger.debug("host_b: %s", host_b.get_addrs())
@ -62,8 +64,9 @@ async def test_identify_protocol(security_protocol):
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
logger.debug("cleaned_addr= %s", cleaned_addr)
assert identify_response.observed_addr == _multiaddr_to_bytes(cleaned_addr)
# The observed address should match the cleaned address
assert Multiaddr(identify_response.observed_addr) == cleaned_addr
# Check protocols
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())

View File

@ -0,0 +1,410 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
    """
    In-memory stand-in for a stream, used by the identify format tests.

    Serves a fixed byte string through ``read`` and ignores everything passed
    to ``write``.
    """

    def __init__(self, data: bytes):
        self.data = data
        self.position = 0
        self.closed = False

    async def read(self, n: int | None = None) -> bytes:
        """Return up to ``n`` bytes, or everything left when ``n`` is None."""
        exhausted = self.position >= len(self.data)
        if self.closed or exhausted:
            return b""
        start = self.position
        stop = len(self.data) if n is None else start + n
        chunk = self.data[start:stop]
        self.position = start + len(chunk)
        return chunk

    async def write(self, data: bytes) -> None:
        """Accept and discard ``data``; nothing is recorded."""
        return None

    async def close(self) -> None:
        """Mark the stream closed so later reads return ``b""``."""
        self.closed = True
def create_identify_message(host, observed_multiaddr=None):
    """Build the identify protobuf for ``host`` via ``_mk_identify_protobuf``."""
    identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
    return identify_msg
def create_new_format_message(identify_msg):
    """Serialize ``identify_msg`` and prepend its varint length (new format)."""
    return encode_varint_prefixed(identify_msg.SerializeToString())
def create_old_format_message(identify_msg):
    """Serialize ``identify_msg`` with no framing (old, raw protobuf format)."""
    raw_bytes = identify_msg.SerializeToString()
    return raw_bytes
async def read_new_format_message(stream) -> bytes:
    """
    Read one varint length-prefixed message from ``stream``.

    The prefix is collected one byte at a time until a byte without the
    continuation bit (0x80) arrives or the stream ends, then decoded with
    ``decode_varint_from_bytes``. The payload is fetched with a single read.

    :raises ValueError: if no prefix byte is available, or the payload read
        returns a different number of bytes than the prefix announced
    """
    prefix = b""
    while byte := await stream.read(1):
        prefix += byte
        if not (byte[0] & 0x80):
            # Continuation bit clear: this was the last prefix byte.
            break

    if not prefix:
        raise ValueError("No length prefix received")

    expected_len = decode_varint_from_bytes(prefix)
    payload = await stream.read(expected_len)
    if len(payload) != expected_len:
        raise ValueError("Incomplete message received")
    return payload
async def read_old_format_message(stream) -> bytes:
    """
    Read an old-format (raw protobuf) message.

    There is no framing, so the stream is drained in 4096-byte reads until a
    read returns no data, and the concatenation is returned.
    """
    pieces = []
    while piece := await stream.read(4096):
        pieces.append(piece)
    return b"".join(pieces)
async def read_compatible_message(stream) -> bytes:
    """
    Read an identify message in either old or new format.

    The first bytes are probed as a varint length prefix. The new format is
    accepted only when the announced length is between 1 byte and 1 MB, that
    many payload bytes can actually be read, and the payload parses as an
    ``Identify`` protobuf. In every other case the data is treated as an old
    format (raw protobuf) message and all bytes read so far plus the rest of
    the stream are returned.

    :param stream: object with an async ``read(n)`` returning bytes, ``b""``
        once exhausted
    :return: serialized ``Identify`` bytes (length prefix stripped for the
        new format, the full stream contents for the old format)
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything taken off the stream so far. The raw-format fallback must
    # start from this, otherwise bytes consumed while probing for the new
    # format would be lost.
    consumed = first_bytes

    try:
        msg_length = decode_varint_from_bytes(first_bytes)
    except Exception:
        # Best-effort probe: anything undecodable is simply "not new format".
        msg_length = 0

    if 0 < msg_length <= 1024 * 1024:  # Max 1MB
        # Number of bytes the varint prefix occupies.
        varint_len = next(
            (idx + 1 for idx, byte in enumerate(first_bytes) if not byte & 0x80),
            len(first_bytes),
        )
        # Payload bytes still to fetch. A negative value means data follows
        # the announced payload, which a single framed message never has, so
        # that case is left to the raw-format fallback (and never turned into
        # a negative-size read).
        missing = msg_length - (len(first_bytes) - varint_len)
        while missing > 0:
            chunk = await stream.read(missing)
            if not chunk:
                break
            consumed += chunk
            missing -= len(chunk)

        if missing == 0:
            message_data = consumed[varint_len:]
            # Try to parse as protobuf to validate
            try:
                Identify().ParseFromString(message_data)
            except Exception:
                # If protobuf parsing fails, fall back to old format
                pass
            else:
                return message_data

    # Fall back to old format (raw protobuf): keep what was already read and
    # drain the rest of the stream.
    chunks = [consumed]
    while chunk := await stream.read(4096):
        chunks.append(chunk)
    return b"".join(chunks)
async def read_compatible_message_simple(stream) -> bytes:
    """
    Read a message in either old or new format (simplified version for testing).

    Works like a format probe without protobuf validation: if the first bytes
    decode to a varint length between 1 byte and 1 MB and exactly that many
    payload bytes can be read, the payload is returned. Otherwise all bytes
    read so far plus the rest of the stream are returned as raw data.

    :param stream: object with an async ``read(n)`` returning bytes, ``b""``
        once exhausted
    :return: the payload without its length prefix (new format) or the full
        stream contents (old format)
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything taken off the stream so far; the raw fallback starts from
    # here so bytes consumed while probing are not dropped.
    consumed = first_bytes

    try:
        msg_length = decode_varint_from_bytes(first_bytes)
    except Exception:
        # Best-effort probe: anything undecodable is simply "not new format".
        msg_length = 0

    if 0 < msg_length <= 1024 * 1024:  # Max 1MB
        # Number of bytes the varint prefix occupies.
        varint_len = next(
            (idx + 1 for idx, byte in enumerate(first_bytes) if not byte & 0x80),
            len(first_bytes),
        )
        # Payload bytes still to fetch. Negative means data follows the
        # announced payload; that is handled by the raw fallback instead of
        # issuing a negative-size read.
        missing = msg_length - (len(first_bytes) - varint_len)
        while missing > 0:
            chunk = await stream.read(missing)
            if not chunk:
                break
            consumed += chunk
            missing -= len(chunk)

        if missing == 0:
            return consumed[varint_len:]

    # Fall back to old format (raw data)
    chunks = [consumed]
    while chunk := await stream.read(4096):
        chunks.append(chunk)
    return b"".join(chunks)
def detect_format(data):
    """
    Classify ``data`` as "new" (varint-prefixed) or "old" (raw protobuf).

    Returns "unknown" for empty input. Returns "new" only when the leading
    varint announces a length between 1 byte and 1 MB, ``data`` holds at least
    that many bytes after the prefix, and those bytes parse as ``Identify``.
    Any failure along the way yields "old".
    """
    if not data:
        return "unknown"

    try:
        declared_len = decode_varint_from_bytes(data)
        if 0 < declared_len <= 1024 * 1024:  # Max 1MB
            # Size of the varint prefix: up to and including the first byte
            # whose continuation bit is clear.
            prefix_len = next(
                (pos + 1 for pos, byte in enumerate(data) if not byte & 0x80),
                len(data),
            )
            if len(data) >= prefix_len + declared_len:
                payload = data[prefix_len : prefix_len + declared_len]
                Identify().ParseFromString(payload)
                return "new"
    except Exception:
        # Undecodable prefix or unparsable payload: not a valid new format.
        pass

    return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
    """
    Round-trip a new format (length-prefixed) identify message.

    host_a's identify message is varint-framed, served from a MockStream, read
    back with ``read_new_format_message`` and compared field by field.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)
        stream = MockStream(create_new_format_message(identify_msg))

        payload = await read_new_format_message(stream)

        parsed_msg = Identify()
        parsed_msg.ParseFromString(payload)
        for field in ("protocol_version", "agent_version", "public_key"):
            assert getattr(parsed_msg, field) == getattr(identify_msg, field)
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
    """
    Round-trip an old format (raw protobuf) identify message.

    host_a's identify message is serialized without framing, served from a
    MockStream, read back with ``read_old_format_message`` and compared field
    by field.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)
        stream = MockStream(create_old_format_message(identify_msg))

        payload = await read_old_format_message(stream)

        parsed_msg = Identify()
        parsed_msg.ParseFromString(payload)
        for field in ("protocol_version", "agent_version", "public_key"):
            assert getattr(parsed_msg, field) == getattr(identify_msg, field)
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
    """
    Check that an old format message survives a read/parse cycle.

    NOTE: the message is read with the dedicated ``read_old_format_message``
    reader, not with ``read_compatible_message``.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        original = create_identify_message(host_a)
        raw_bytes = create_old_format_message(original)

        received = await read_old_format_message(MockStream(raw_bytes))

        decoded = Identify()
        decoded.ParseFromString(received)
        assert decoded.protocol_version == original.protocol_version
        assert decoded.agent_version == original.agent_version
        assert decoded.public_key == original.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
    """
    Check that a new format message survives a read/parse cycle.

    NOTE: the message is read with the dedicated ``read_new_format_message``
    reader, not with ``read_compatible_message``.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        original = create_identify_message(host_a)
        framed_bytes = create_new_format_message(original)

        received = await read_new_format_message(MockStream(framed_bytes))

        decoded = Identify()
        decoded.ParseFromString(received)
        assert decoded.protocol_version == original.protocol_version
        assert decoded.agent_version == original.agent_version
        assert decoded.public_key == original.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
    """Check that ``detect_format`` tells framed and raw identify bytes apart."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)

        # Length-prefixed bytes must be classified as the new format.
        framed = create_new_format_message(identify_msg)
        assert detect_format(framed) == "new", (
            "New format should be detected correctly"
        )

        # Bare protobuf bytes must be classified as the old format.
        raw = create_old_format_message(identify_msg)
        assert detect_format(raw) == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
    """Check how the readers react to empty, truncated and garbage input."""
    from libp2p.exceptions import ParseError

    # An empty stream gives the compatible reader nothing to probe.
    empty_stream = MockStream(b"")
    with pytest.raises(ValueError, match="No data received"):
        await read_compatible_message(empty_stream)

    # A lone byte with the continuation bit set is a varint that never ends.
    truncated_prefix = MockStream(b"\x80")  # Incomplete varint
    with pytest.raises(ParseError, match="Unexpected end of data"):
        await read_new_format_message(truncated_prefix)

    # A length prefix followed by bytes that are not a valid protobuf; the
    # failure is expected somewhere in reading or parsing.
    garbage = MockStream(b"\x05invalid")
    with pytest.raises(Exception):
        payload = await read_new_format_message(garbage)
        Identify().ParseFromString(payload)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
    """Check that the new format is the old format plus a varint prefix."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)
        framed = create_new_format_message(identify_msg)
        raw = create_old_format_message(identify_msg)

        # The prefix ends with the first byte whose continuation bit is clear.
        prefix_len = next(
            (pos + 1 for pos, byte in enumerate(framed) if not byte & 0x80),
            len(framed),
        )
        unframed = framed[prefix_len:]

        assert unframed == raw, "Protobuf messages should be identical in both formats"

View File

@ -459,7 +459,11 @@ async def test_push_identify_to_peers_respects_concurrency_limit():
lock = trio.Lock()
async def mock_push_identify_to_peer(
host, peer_id, observed_multiaddr=None, limit=trio.Semaphore(CONCURRENCY_LIMIT)
host,
peer_id,
observed_multiaddr=None,
limit=trio.Semaphore(CONCURRENCY_LIMIT),
use_varint_format=True,
) -> bool:
"""
Mock function to test concurrency by simulating an identify message.
@ -593,3 +597,104 @@ async def test_all_peers_receive_identify_push_with_semaphore_under_high_peer_lo
assert peer_id_a in dummy_peerstore.peer_ids()
nursery.cancel_scope.cancel()
@pytest.mark.trio
async def test_identify_push_default_varint_format(security_protocol):
    """
    Push identify data with default settings and check the receiver's peerstore.

    Neither the handler nor the push call passes ``use_varint_format``, so both
    sides run with their defaults (expected to be the varint, length-prefixed
    format). After the push, host_b's peerstore must contain host_a's
    addresses, protocols and public key.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Receiver side: register the push handler with default settings.
        host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))

        # Sender side: push with default settings.
        pushed = await push_identify_to_peer(host_a, host_b.get_id())
        assert pushed, "Identify push should succeed with default varint format"

        # Give the receiving handler a moment to process the push.
        await trio.sleep(0.1)

        store = host_b.get_peerstore()
        sender_id = host_a.get_id()
        assert sender_id in store.peer_ids()

        # Every address host_a advertises must now be known to host_b.
        advertised_addrs = set(host_a.get_addrs())
        assert advertised_addrs <= set(store.addrs(sender_id))

        # Same for the protocols host_a supports.
        advertised_protocols = set(host_a.get_mux().get_protocols())
        assert advertised_protocols <= set(store.get_protocols(sender_id))

        # The stored public key must match host_a's own key.
        own_key = host_a.get_public_key().serialize()
        stored_key = store.pubkey(sender_id).serialize()
        assert own_key == stored_key
@pytest.mark.trio
async def test_identify_push_legacy_raw_format(security_protocol):
    """
    Push identify data in the legacy raw format and check the peerstore.

    Both the handler and the push call are given ``use_varint_format=False``.
    After the push, host_b's peerstore must contain host_a's addresses,
    protocols and public key.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Receiver side: handler configured for the legacy format.
        legacy_handler = identify_push_handler_for(host_b, use_varint_format=False)
        host_b.set_stream_handler(ID_PUSH, legacy_handler)

        # Sender side: push using the legacy format as well.
        pushed = await push_identify_to_peer(
            host_a, host_b.get_id(), use_varint_format=False
        )
        assert pushed, "Identify push should succeed with legacy raw format"

        # Give the receiving handler a moment to process the push.
        await trio.sleep(0.1)

        store = host_b.get_peerstore()
        sender_id = host_a.get_id()
        assert sender_id in store.peer_ids()

        # Every address host_a advertises must now be known to host_b.
        advertised_addrs = set(host_a.get_addrs())
        assert advertised_addrs <= set(store.addrs(sender_id))

        # Same for the protocols host_a supports.
        advertised_protocols = set(host_a.get_mux().get_protocols())
        assert advertised_protocols <= set(store.get_protocols(sender_id))

        # The stored public key must match host_a's own key.
        own_key = host_a.get_public_key().serialize()
        stored_key = store.pubkey(sender_id).serialize()
        assert own_key == stored_key

View File

@ -0,0 +1,215 @@
import pytest
from libp2p.exceptions import ParseError
from libp2p.io.abc import Reader
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_uvarint,
encode_varint_prefixed,
read_varint_prefixed_bytes,
)
class MockReader(Reader):
    """In-memory reader that serves a fixed byte string to the varint tests."""

    def __init__(self, data: bytes):
        self.data = data
        self.position = 0

    async def read(self, n: int | None = None) -> bytes:
        """Return up to ``n`` bytes, or everything left when ``n`` is None."""
        start = self.position
        if start >= len(self.data):
            return b""
        stop = len(self.data) if n is None else start + n
        chunk = self.data[start:stop]
        self.position = start + len(chunk)
        return chunk
def test_encode_uvarint():
    """Check ``encode_uvarint`` against known encodings around byte boundaries."""
    expected_by_value = {
        0: b"\x00",
        1: b"\x01",
        127: b"\x7f",
        128: b"\x80\x01",
        255: b"\xff\x01",
        256: b"\x80\x02",
        65535: b"\xff\xff\x03",
        65536: b"\x80\x80\x04",
        16777215: b"\xff\xff\xff\x07",
        16777216: b"\x80\x80\x80\x08",
    }
    for value, expected in expected_by_value.items():
        result = encode_uvarint(value)
        assert result == expected, (
            f"Failed for value {value}: expected {expected.hex()}, got {result.hex()}"
        )
def test_decode_varint_from_bytes():
    """Check ``decode_varint_from_bytes`` against known encodings."""
    expected_by_data = {
        b"\x00": 0,
        b"\x01": 1,
        b"\x7f": 127,
        b"\x80\x01": 128,
        b"\xff\x01": 255,
        b"\x80\x02": 256,
        b"\xff\xff\x03": 65535,
        b"\x80\x80\x04": 65536,
        b"\xff\xff\xff\x07": 16777215,
        b"\x80\x80\x80\x08": 16777216,
    }
    for data, expected in expected_by_data.items():
        result = decode_varint_from_bytes(data)
        assert result == expected, (
            f"Failed for data {data.hex()}: expected {expected}, got {result}"
        )
def test_decode_varint_from_bytes_invalid():
    """
    Test varint decoding with invalid data.

    Both an empty buffer and a buffer that ends while the continuation bit is
    still set must be rejected with ``ParseError``.
    """
    # Empty data
    with pytest.raises(ParseError, match="Unexpected end of data"):
        decode_varint_from_bytes(b"")
    # Truncated varint: the only byte has its continuation bit (0x80) set, so
    # the decoder runs out of input before the value is complete.
    with pytest.raises(ParseError, match="Unexpected end of data"):
        decode_varint_from_bytes(b"\x80")
def test_encode_varint_prefixed():
    """Check that ``encode_varint_prefixed`` prepends the varint length."""
    bodies = (b"", b"hello", b"x" * 127, b"x" * 128)
    prefixes = (b"\x00", b"\x05", b"\x7f", b"\x80\x01")
    for message, prefix in zip(bodies, prefixes):
        expected = prefix + message
        result = encode_varint_prefixed(message)
        assert result == expected, (
            f"Failed for message {message}: expected {expected.hex()}, "
            f"got {result.hex()}"
        )
@pytest.mark.trio
async def test_read_varint_prefixed_bytes():
    """Check that ``read_varint_prefixed_bytes`` returns the framed payload."""
    # Each message is framed and must come back unchanged.
    for message in (b"", b"hello", b"x" * 127, b"x" * 128):
        expected = message
        reader = MockReader(encode_varint_prefixed(message))
        result = await read_varint_prefixed_bytes(reader)
        assert result == expected, (
            f"Failed for message {message}: expected {expected}, got {result}"
        )
@pytest.mark.trio
async def test_read_varint_prefixed_bytes_incomplete():
    """Check that truncated input makes ``read_varint_prefixed_bytes`` fail."""
    from libp2p.io.exceptions import IncompleteReadError

    # Prefix cut off: a single byte with the continuation bit set.
    truncated_prefix = MockReader(b"\x80")
    with pytest.raises(IncompleteReadError):
        await read_varint_prefixed_bytes(truncated_prefix)

    # Payload cut off: the last 3 bytes of the framed message are missing.
    framed = encode_varint_prefixed(b"hello world")
    truncated_payload = MockReader(framed[:-3])
    with pytest.raises(IncompleteReadError):
        await read_varint_prefixed_bytes(truncated_payload)
def test_varint_roundtrip():
    """Check that decoding an encoded value gives the value back."""
    for value in (0, 1, 127, 128, 255, 256, 65535, 65536, 16777215, 16777216):
        encoded = encode_uvarint(value)
        decoded = decode_varint_from_bytes(encoded)
        assert decoded == value, (
            f"Roundtrip failed for {value}: encoded={encoded.hex()}, decoded={decoded}"
        )
def test_varint_prefixed_roundtrip():
    """Check that a framed message decodes to its own length and payload."""
    for message in (b"", b"hello", b"x" * 127, b"x" * 128, b"x" * 1000):
        prefixed = encode_varint_prefixed(message)

        # The leading varint must announce the payload length.
        length = decode_varint_from_bytes(prefixed)
        assert length == len(message), (
            f"Length mismatch for {message}: expected {len(message)}, got {length}"
        )

        # The prefix ends with the first byte whose continuation bit is clear;
        # everything after it is the payload.
        prefix_len = next(
            (pos + 1 for pos, byte in enumerate(prefixed) if not byte & 0x80),
            len(prefixed),
        )
        extracted_message = prefixed[prefix_len:]
        assert extracted_message == message, (
            f"Message mismatch: expected {message}, got {extracted_message}"
        )
def test_large_varint_values():
    """
    Test varint encoding/decoding with large values.

    A value the implementation refuses to encode or decode skips the test. A
    value that encodes and decodes but does not round-trip fails it.
    """
    large_values = [
        2**32 - 1,  # 32-bit max
        2**64 - 1,  # 64-bit max (if supported)
    ]
    for value in large_values:
        try:
            encoded = encode_uvarint(value)
            decoded = decode_varint_from_bytes(encoded)
        except Exception as e:
            # Some implementations might not support very large values
            pytest.skip(f"Large value {value} not supported: {e}")
        # Kept outside the try block: AssertionError is an Exception, so an
        # assertion inside it would be turned into a skip instead of a failure.
        assert decoded == value, f"Large value roundtrip failed for {value}"
def test_varint_edge_cases():
    """Check encoding and decoding on both sides of the 7- and 14-bit limits."""
    boundaries = [
        (127, b"\x7f"),  # maximum 7-bit value
        (128, b"\x80\x01"),  # minimum 8-bit value
        (16383, b"\xff\x7f"),  # maximum 14-bit value
        (16384, b"\x80\x80\x01"),  # minimum 15-bit value
    ]
    for value, encoded in boundaries:
        assert encode_uvarint(value) == encoded
        assert decode_varint_from_bytes(encoded) == value