fix raw format in identify and tests

This commit is contained in:
acul71
2025-07-19 04:11:27 +02:00
parent 7cfe5b9dc7
commit 99db5b309f
3 changed files with 356 additions and 418 deletions

View File

@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
format_flag = "--raw-format" if not use_varint_format else ""
print(
f"First host listening (using {format_name} format). "
f"Run this from another console:\n\n"
f"identify-demo "
f"-d {client_addr}\n"
f"identify-demo {format_flag} -d {client_addr}\n"
)
print("Waiting for incoming identify request...")
# Add a custom handler to show connection events
async def custom_identify_handler(stream):
    """
    Log an incoming identify request, then hand it to the real handler.

    Prints the requesting peer's ID and its remote address before awaiting
    ``identify_handler`` (taken from the enclosing scope), and prints a
    confirmation line once that handler returns.
    """
    requester = stream.muxed_conn.peer_id
    print(f"\n🔗 Received identify request from peer: {requester}")

    # Best-effort display of the remote address as a multiaddr; any failure
    # falls back to printing whatever the stream reports.
    try:
        from libp2p.identity.identify.identify import (
            _remote_address_to_multiaddr,
        )

        peer_address = stream.get_remote_address()
        if not peer_address:
            print(f" Remote address: {peer_address}")
        else:
            base_multiaddr = _remote_address_to_multiaddr(peer_address)
            # Append the peer ID so the printed address is a full multiaddr.
            full_multiaddr = f"{base_multiaddr}/p2p/{requester}"
            print(f" Remote address: {full_multiaddr}")
    except Exception:
        print(f" Remote address: {stream.get_remote_address()}")

    # Delegate the actual protocol work to the wrapped handler.
    await identify_handler(stream)
    print(f"✅ Successfully processed identify request from {requester}")
# Replace the handler with our custom one
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
await trio.sleep_forever()
else:
@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
await host_b.connect(info)
try:
await host_b.connect(info)
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print(
"\n💡 Make sure the peer is running and the address is correct."
)
return
else:
# Re-raise other exceptions
raise
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
try:
print("Starting identify protocol...")
# Read the complete response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
# Read the response properly based on the format
if use_varint_format:
# For length-prefixed format, read varint length first
from libp2p.utils.varint import decode_varint_from_bytes
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("Stream closed while reading varint length")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
print(f"Expected message length: {msg_length} bytes")
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length} bytes, "
f"got {len(response)}"
)
# Combine length prefix and message
full_response = length_bytes + response
else:
# For raw format, read all available data
response = await stream.read(8192)
full_response = response
await stream.close()
# Parse the response using the robust protocol-level function
# This handles both old and new formats automatically
identify_msg = parse_identify_response(response)
identify_msg = parse_identify_response(full_response)
print_identify_response(identify_msg)
except Exception as e:
print(f"Identify protocol error: {e}")
error_msg = str(e)
print(f"Identify protocol error: {error_msg}")
# Check for specific format mismatch errors
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener"
)
print("is using raw protobuf format.")
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(f"identify-demo --raw-format -d {destination}")
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-demo -d {destination}")
print("=" * 60)
else:
import traceback
traceback.print_exc()
return

View File

@ -0,0 +1,241 @@
import logging
import pytest
from libp2p.custom_types import TProtocol
from libp2p.identity.identify.identify import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_multiaddr_to_bytes,
identify_handler_for,
parse_identify_response,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-integration-test")
@pytest.mark.trio
async def test_identify_protocol_varint_format_integration(security_protocol):
    """Check an identify exchange when the listener uses the varint format."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # The listener answers identify requests with use_varint_format=True.
        handler = identify_handler_for(listener, use_varint_format=True)
        listener.set_stream_handler(ID, handler)

        # Open an identify stream from the dialer and take the reply in one read.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        parsed = parse_identify_response(raw_reply)

        # The parsed reply must describe the listener.
        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
        expected_addrs = [_multiaddr_to_bytes(addr) for addr in listener.get_addrs()]
        assert parsed.listen_addrs == expected_addrs
@pytest.mark.trio
async def test_identify_protocol_raw_format_integration(security_protocol):
    """Check an identify exchange when the listener uses the raw format."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # The listener answers identify requests with use_varint_format=False.
        handler = identify_handler_for(listener, use_varint_format=False)
        listener.set_stream_handler(ID, handler)

        # Open an identify stream from the dialer and take the reply in one read.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        parsed = parse_identify_response(raw_reply)

        # The parsed reply must describe the listener.
        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
        expected_addrs = [_multiaddr_to_bytes(addr) for addr in listener.get_addrs()]
        assert parsed.listen_addrs == expected_addrs
@pytest.mark.trio
async def test_identify_default_format_behavior(security_protocol):
    """Check an identify exchange when no format flag is passed to the handler."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # No use_varint_format argument: whatever the factory defaults to is used.
        default_handler = identify_handler_for(listener)
        listener.set_stream_handler(ID, default_handler)

        # Open an identify stream from the dialer and take the reply in one read.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        parsed = parse_identify_response(raw_reply)

        # The parsed reply must describe the listener.
        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
    """Check that a raw-format listener's reply is parsed without a format hint."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # Listener side: raw protobuf (use_varint_format=False).
        raw_handler = identify_handler_for(listener, use_varint_format=False)
        listener.set_stream_handler(ID, raw_handler)

        # Dialer side: read the reply without knowing which format was sent.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        # parse_identify_response is given the bytes as-is, with no format flag.
        parsed = parse_identify_response(raw_reply)

        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
    """Check that a varint-format listener's reply is parsed without a format hint."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # Listener side: length-prefixed (use_varint_format=True).
        varint_handler = identify_handler_for(listener, use_varint_format=True)
        listener.set_stream_handler(ID, varint_handler)

        # Dialer side: read the reply without knowing which format was sent.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        # parse_identify_response is given the bytes as-is, with no format flag.
        parsed = parse_identify_response(raw_reply)

        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_format_detection_robustness(security_protocol):
    """Run the identify exchange once per wire format on the same host pair."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # Varint first, then raw; the handler is re-registered for each pass.
        for use_varint in (True, False):
            handler = identify_handler_for(listener, use_varint_format=use_varint)
            listener.set_stream_handler(ID, handler)

            # Fresh stream per pass, reply taken in a single read.
            identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
            raw_reply = await identify_stream.read(8192)
            await identify_stream.close()

            parsed = parse_identify_response(raw_reply)

            assert parsed.agent_version == AGENT_VERSION
            assert parsed.protocol_version == PROTOCOL_VERSION
            assert parsed.public_key == listener.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_large_message_handling(security_protocol):
    """Check the identify exchange after registering ten extra protocol handlers."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):

        async def noop_handler(stream):
            pass

        # Register /test/protocol/0 .. /test/protocol/9 on the listener so it
        # has more protocols registered before identify runs.
        for index in range(10):
            protocol_id = TProtocol(f"/test/protocol/{index}")
            listener.set_stream_handler(protocol_id, noop_handler)

        varint_handler = identify_handler_for(listener, use_varint_format=True)
        listener.set_stream_handler(ID, varint_handler)

        # Open an identify stream from the dialer and take the reply in one read.
        identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
        raw_reply = await identify_stream.read(8192)
        await identify_stream.close()

        parsed = parse_identify_response(raw_reply)

        assert parsed.agent_version == AGENT_VERSION
        assert parsed.protocol_version == PROTOCOL_VERSION
        assert parsed.public_key == listener.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_message_equivalence_real_network(security_protocol):
    """Check that varint and raw replies from one listener parse to equal fields."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        listener,
        dialer,
    ):
        # Collect one raw reply per format: varint first, then raw protobuf.
        replies = {}
        for use_varint in (True, False):
            handler = identify_handler_for(listener, use_varint_format=use_varint)
            listener.set_stream_handler(ID, handler)
            identify_stream = await dialer.new_stream(listener.get_id(), (ID,))
            replies[use_varint] = await identify_stream.read(8192)
            await identify_stream.close()

        # Parse only after both exchanges have completed.
        parsed_varint = parse_identify_response(replies[True])
        parsed_raw = parse_identify_response(replies[False])

        # Field-by-field comparison of the two parsed messages.
        assert parsed_varint.agent_version == parsed_raw.agent_version
        assert parsed_varint.protocol_version == parsed_raw.protocol_version
        assert parsed_varint.public_key == parsed_raw.public_key
        assert parsed_varint.listen_addrs == parsed_raw.listen_addrs

View File

@ -1,410 +0,0 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
    """
    In-memory stream stand-in that serves a fixed byte string.

    ``data`` is the full payload, ``position`` is the read cursor into it and
    ``closed`` records whether ``close`` has been called.
    """

    def __init__(self, data: bytes):
        self.data = data
        self.position = 0
        self.closed = False

    async def read(self, n: int | None = None) -> bytes:
        """Return up to ``n`` bytes from the cursor, or all remaining if ``n`` is None."""
        remaining = len(self.data) - self.position
        # A closed or exhausted stream signals EOF with an empty result.
        if self.closed or remaining <= 0:
            return b""
        count = remaining if n is None else n
        start = self.position
        chunk = self.data[start : start + count]
        self.position = start + len(chunk)
        return chunk

    async def write(self, data: bytes) -> None:
        """Accept and discard ``data``; nothing is recorded."""
        pass

    async def close(self) -> None:
        """Mark the stream closed so later reads return no data."""
        self.closed = True
def create_identify_message(host, observed_multiaddr=None):
    """Build the identify message for ``host`` via ``_mk_identify_protobuf``."""
    identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
    return identify_msg
def create_new_format_message(identify_msg):
    """Serialize ``identify_msg`` and wrap it with ``encode_varint_prefixed``."""
    return encode_varint_prefixed(identify_msg.SerializeToString())
def create_old_format_message(identify_msg):
    """Serialize ``identify_msg`` to raw bytes with no length prefix added."""
    raw_bytes = identify_msg.SerializeToString()
    return raw_bytes
async def read_new_format_message(stream) -> bytes:
    """
    Read one length-prefixed (new format) identify message from ``stream``.

    The varint length prefix is consumed one byte at a time, then exactly that
    many payload bytes are collected. ``stream.read(n)`` may legitimately
    return fewer than ``n`` bytes, so the payload is gathered in a loop until
    it is complete or the stream reports EOF.

    :param stream: object exposing ``async read(n) -> bytes``
    :return: the message payload without its length prefix
    :raises ValueError: if no prefix byte arrives, or the stream ends before
        the announced number of payload bytes was received
    """
    # Read varint length prefix: the last byte has its high bit cleared.
    length_bytes = b""
    while True:
        b = await stream.read(1)
        if not b:
            break
        length_bytes += b
        if b[0] & 0x80 == 0:
            break

    if not length_bytes:
        raise ValueError("No length prefix received")

    # A truncated prefix is passed through so the decoder reports the error.
    msg_length = decode_varint_from_bytes(length_bytes)

    # Keep reading until the payload is complete; an empty chunk means EOF.
    response = b""
    while len(response) < msg_length:
        chunk = await stream.read(msg_length - len(response))
        if not chunk:
            break
        response += chunk

    if len(response) != msg_length:
        raise ValueError("Incomplete message received")

    return response
async def read_old_format_message(stream) -> bytes:
    """
    Read an old format (raw protobuf) identify message.

    Drains ``stream`` in 4096-byte requests until a read returns nothing, and
    returns everything received as one bytes object.
    """
    pieces = []
    # An empty read marks the end of the stream.
    while chunk := await stream.read(4096):
        pieces.append(chunk)
    return b"".join(pieces)
async def read_compatible_message(stream) -> bytes:
    """
    Read an identify message in either old (raw) or new (prefixed) format.

    Up to 10 bytes are read first and interpreted as a varint length prefix.
    If the announced length is plausible, the payload is completed from the
    stream and validated by parsing it as an ``Identify`` protobuf. When any
    step fails, the function falls back to treating the whole stream as one
    raw protobuf message.

    Every byte taken from the stream is kept in a single buffer, so the raw
    fallback starts from all data consumed so far rather than only the first
    chunk. No read with a non-positive size is issued when the payload already
    sits inside the first chunk.

    :param stream: object exposing ``async read(n) -> bytes``
    :return: the protobuf payload (prefix stripped when one was detected)
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything pulled off the stream so far; reused by the raw fallback.
    consumed = first_bytes

    # Try to decode as varint length prefix (new format)
    try:
        msg_length = decode_varint_from_bytes(first_bytes)

        # Validate that the length is reasonable (not too large): max 1MB
        if 0 < msg_length <= 1024 * 1024:
            # Count the prefix bytes: the last one has its high bit cleared.
            varint_len = 0
            for byte in first_bytes:
                varint_len += 1
                if (byte & 0x80) == 0:
                    break

            # Fetch only what the first chunk did not already deliver.
            missing = msg_length - (len(first_bytes) - varint_len)
            if missing > 0:
                consumed += await stream.read(missing)

            message_data = consumed[varint_len : varint_len + msg_length]
            if len(message_data) == msg_length:
                # Try to parse as protobuf to validate
                try:
                    Identify().ParseFromString(message_data)
                    return message_data
                except Exception:
                    # Not a valid prefixed message: fall back to old format
                    pass
    except Exception:
        # Best-effort detection: any failure means "treat as raw".
        pass

    # Fall back to old format (raw protobuf), keeping all consumed bytes.
    response = consumed
    while True:
        chunk = await stream.read(4096)
        if not chunk:
            break
        response += chunk

    return response
async def read_compatible_message_simple(stream) -> bytes:
    """
    Read a message in either old or new format (simplified version for testing).

    Works like the validating reader but skips the protobuf check: if the
    first bytes decode to a plausible varint length and that many payload
    bytes can be obtained, the payload is returned. Otherwise the whole stream
    is returned as raw data.

    Every byte taken from the stream is kept in a single buffer, so a short
    follow-up read no longer loses data on the raw fallback. No read with a
    non-positive size is issued when the payload already sits inside the first
    chunk.

    :param stream: object exposing ``async read(n) -> bytes``
    :return: the payload (prefix stripped when one was detected)
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything pulled off the stream so far; reused by the raw fallback.
    consumed = first_bytes

    # Try to decode as varint length prefix (new format)
    try:
        msg_length = decode_varint_from_bytes(first_bytes)

        # Validate that the length is reasonable (not too large): max 1MB
        if 0 < msg_length <= 1024 * 1024:
            # Count the prefix bytes: the last one has its high bit cleared.
            varint_len = 0
            for byte in first_bytes:
                varint_len += 1
                if (byte & 0x80) == 0:
                    break

            # Fetch only what the first chunk did not already deliver.
            missing = msg_length - (len(first_bytes) - varint_len)
            if missing > 0:
                consumed += await stream.read(missing)

            payload = consumed[varint_len : varint_len + msg_length]
            if len(payload) == msg_length:
                return payload
    except Exception:
        # Best-effort detection: any failure means "treat as raw".
        pass

    # Fall back to old format (raw data), keeping all consumed bytes.
    response = consumed
    while True:
        chunk = await stream.read(4096)
        if not chunk:
            break
        response += chunk

    return response
def detect_format(data):
    """
    Classify ``data`` as a varint-prefixed ("new") or raw ("old") message.

    Returns "unknown" for empty input, "new" when the data starts with a
    plausible varint length, contains that many payload bytes and the payload
    parses as an ``Identify`` protobuf, and "old" in every other case.
    """
    if not data:
        return "unknown"

    try:
        msg_length = decode_varint_from_bytes(data)

        # Reject empty or oversized (above 1MB) announced lengths.
        if not (0 < msg_length <= 1024 * 1024):
            return "old"

        # Count the prefix bytes: the last one has its high bit cleared.
        prefix_len = 0
        for octet in data:
            prefix_len += 1
            if not octet & 0x80:
                break

        # The buffer must hold the whole announced payload.
        if len(data) < prefix_len + msg_length:
            return "old"

        # Final check: the payload has to parse as an Identify message.
        body = data[prefix_len : prefix_len + msg_length]
        Identify().ParseFromString(body)
        return "new"
    except Exception:
        # Any decoding or parsing failure means this is not the new format.
        return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
    """Round-trip a length-prefixed identify message through the new-format reader."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Build the reference message and its length-prefixed encoding.
        expected = create_identify_message(host_a)
        encoded = create_new_format_message(expected)

        # Serve the encoded bytes from an in-memory stream and read them back.
        payload = await read_new_format_message(MockStream(encoded))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the reference.
        assert decoded.protocol_version == expected.protocol_version
        assert decoded.agent_version == expected.agent_version
        assert decoded.public_key == expected.public_key
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
    """Round-trip a raw protobuf identify message through the old-format reader."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Build the reference message and its raw (unprefixed) encoding.
        expected = create_identify_message(host_a)
        encoded = create_old_format_message(expected)

        # Serve the encoded bytes from an in-memory stream and read them back.
        payload = await read_old_format_message(MockStream(encoded))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the reference.
        assert decoded.protocol_version == expected.protocol_version
        assert decoded.agent_version == expected.agent_version
        assert decoded.public_key == expected.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
    """Read a raw-format identify message back with the old-format reader."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Build the reference message and its raw (unprefixed) encoding.
        expected = create_identify_message(host_a)
        encoded = create_old_format_message(expected)

        # This test goes through read_old_format_message, not the
        # format-detecting reader.
        payload = await read_old_format_message(MockStream(encoded))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the reference.
        assert decoded.protocol_version == expected.protocol_version
        assert decoded.agent_version == expected.agent_version
        assert decoded.public_key == expected.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
    """Read a length-prefixed identify message back with the new-format reader."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Build the reference message and its length-prefixed encoding.
        expected = create_identify_message(host_a)
        encoded = create_new_format_message(expected)

        # This test goes through read_new_format_message, not the
        # format-detecting reader.
        payload = await read_new_format_message(MockStream(encoded))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the reference.
        assert decoded.protocol_version == expected.protocol_version
        assert decoded.agent_version == expected.agent_version
        assert decoded.public_key == expected.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
    """Check that detect_format labels each encoding of one message correctly."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        reference = create_identify_message(host_a)

        # Length-prefixed bytes must be classified as "new".
        prefixed = create_new_format_message(reference)
        detected = detect_format(prefixed)
        assert detected == "new", "New format should be detected correctly"

        # Raw protobuf bytes must be classified as "old".
        unprefixed = create_old_format_message(reference)
        detected = detect_format(unprefixed)
        assert detected == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
    """Check the errors the readers raise for empty, truncated and invalid input."""
    from libp2p.exceptions import ParseError

    # A stream with no data at all is rejected by the compatible reader.
    empty_stream = MockStream(b"")
    with pytest.raises(ValueError, match="No data received"):
        await read_compatible_message(empty_stream)

    # 0x80 has the continuation bit set but nothing follows it.
    truncated_prefix_stream = MockStream(b"\x80")
    with pytest.raises(ParseError, match="Unexpected end of data"):
        await read_new_format_message(truncated_prefix_stream)

    # A length prefix of 5 followed by bytes that are not an Identify message;
    # some step of reading or parsing is expected to raise.
    garbage_stream = MockStream(b"\x05invalid")
    with pytest.raises(Exception):
        payload = await read_new_format_message(garbage_stream)
        Identify().ParseFromString(payload)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
    """Check that the prefixed encoding is the raw encoding plus a varint prefix."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        reference = create_identify_message(host_a)

        # Encode the same message both ways.
        prefixed = create_new_format_message(reference)
        unprefixed = create_old_format_message(reference)

        # Measure the prefix: its last byte has the high bit cleared.
        prefix_len = 0
        for octet in prefixed:
            prefix_len += 1
            if not octet & 0x80:
                break

        # Stripping the prefix must leave exactly the raw encoding.
        stripped = prefixed[prefix_len:]
        assert stripped == unprefixed, (
            "Protobuf messages should be identical in both formats"
        )