mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
feat: add length-prefixed protobuf support to identify protocol
This commit is contained in:
@ -8,9 +8,9 @@ import trio
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
from libp2p.identity.identify.identify import (
|
||||
ID as IDENTIFY_PROTOCOL_ID,
|
||||
parse_identify_response,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
@ -84,11 +84,18 @@ async def run(port: int, destination: str) -> None:
|
||||
|
||||
try:
|
||||
print("Starting identify protocol...")
|
||||
response = await stream.read()
|
||||
|
||||
# Read the complete response (could be either format)
|
||||
# Read a larger chunk to get all the data before stream closes
|
||||
response = await stream.read(8192) # Read enough data in one go
|
||||
|
||||
await stream.close()
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(response)
|
||||
|
||||
# Parse the response using the robust protocol-level function
|
||||
# This handles both old and new formats automatically
|
||||
identify_msg = parse_identify_response(response)
|
||||
print_identify_response(identify_msg)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Identify protocol error: {e}")
|
||||
|
||||
|
||||
@ -26,5 +26,8 @@ if TYPE_CHECKING:
|
||||
|
||||
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
|
||||
return OrderedDict(
|
||||
((IdentifyID, identify_handler_for(host)), (PingID, handle_ping))
|
||||
(
|
||||
(IdentifyID, identify_handler_for(host, use_varint_format=False)),
|
||||
(PingID, handle_ping),
|
||||
)
|
||||
)
|
||||
|
||||
@ -16,7 +16,9 @@ from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
decode_varint_with_size,
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
|
||||
from .pb.identify_pb2 import (
|
||||
@ -72,7 +74,47 @@ def _mk_identify_protobuf(
|
||||
)
|
||||
|
||||
|
||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
def parse_identify_response(response: bytes) -> Identify:
|
||||
"""
|
||||
Parse identify response that could be either:
|
||||
- Old format: raw protobuf
|
||||
- New format: length-prefixed protobuf
|
||||
|
||||
This function provides backward and forward compatibility.
|
||||
"""
|
||||
# Try new format first: length-prefixed protobuf
|
||||
if len(response) >= 1:
|
||||
length, varint_size = decode_varint_with_size(response)
|
||||
if varint_size > 0 and length > 0 and varint_size + length <= len(response):
|
||||
protobuf_data = response[varint_size : varint_size + length]
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(protobuf_data)
|
||||
# Sanity check: must have agent_version (protocol_version is optional)
|
||||
if identify_response.agent_version:
|
||||
logger.debug(
|
||||
"Parsed length-prefixed identify response (new format)"
|
||||
)
|
||||
return identify_response
|
||||
except Exception:
|
||||
pass # Fall through to old format
|
||||
|
||||
# Fall back to old format: raw protobuf
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(response)
|
||||
logger.debug("Parsed raw protobuf identify response (old format)")
|
||||
return identify_response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse identify response: {e}")
|
||||
logger.error(f"Response length: {len(response)}")
|
||||
logger.error(f"Response hex: {response.hex()}")
|
||||
raise
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
peer_id = (
|
||||
@ -100,7 +142,21 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
response = protobuf.SerializeToString()
|
||||
|
||||
try:
|
||||
await stream.write(response)
|
||||
if use_varint_format:
|
||||
# Send length-prefixed protobuf message (new format)
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent new format (length-prefixed) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
else:
|
||||
# Send raw protobuf message (old format for backward compatibility)
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent old format (raw protobuf) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to respond to %s request: stream closed", ID)
|
||||
else:
|
||||
|
||||
@ -25,6 +25,10 @@ from libp2p.peer.id import (
|
||||
)
|
||||
from libp2p.utils import (
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -55,8 +59,29 @@ def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
# Read the identify message from the stream
|
||||
data = await stream.read()
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
@ -159,7 +184,8 @@ async def push_identify_to_peer(
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
|
||||
# Send the identify message
|
||||
# Send length-prefixed identify message
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
|
||||
@ -7,6 +7,8 @@ from libp2p.utils.varint import (
|
||||
encode_varint_prefixed,
|
||||
read_delim,
|
||||
read_varint_prefixed_bytes,
|
||||
decode_varint_from_bytes,
|
||||
decode_varint_with_size,
|
||||
)
|
||||
from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
@ -20,4 +22,6 @@ __all__ = [
|
||||
"get_agent_version",
|
||||
"read_delim",
|
||||
"read_varint_prefixed_bytes",
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
]
|
||||
|
||||
@ -39,6 +39,30 @@ def encode_uvarint(number: int) -> bytes:
|
||||
return buf
|
||||
|
||||
|
||||
def decode_varint_from_bytes(data: bytes) -> int:
|
||||
"""
|
||||
Decode a varint from bytes and return the value.
|
||||
|
||||
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
|
||||
"""
|
||||
res = 0
|
||||
for shift in itertools.count(0, 7):
|
||||
if shift > SHIFT_64_BIT_MAX:
|
||||
raise ParseError("Integer is too large...")
|
||||
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
value = data[0]
|
||||
data = data[1:]
|
||||
|
||||
res += (value & LOW_MASK) << shift
|
||||
|
||||
if not value & HIGH_MASK:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
"""https://en.wikipedia.org/wiki/LEB128."""
|
||||
res = 0
|
||||
@ -56,6 +80,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
return res
|
||||
|
||||
|
||||
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
|
||||
"""
|
||||
Decode a varint from bytes and return (value, bytes_consumed).
|
||||
Returns (0, 0) if the data doesn't start with a valid varint.
|
||||
"""
|
||||
try:
|
||||
# Calculate how many bytes the varint consumes
|
||||
varint_size = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_size += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
if varint_size == 0:
|
||||
return 0, 0
|
||||
|
||||
# Extract just the varint bytes
|
||||
varint_bytes = data[:varint_size]
|
||||
|
||||
# Decode the varint
|
||||
value = decode_varint_from_bytes(varint_bytes)
|
||||
|
||||
return value, varint_size
|
||||
except Exception:
|
||||
return 0, 0
|
||||
|
||||
|
||||
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||
varint_len = encode_uvarint(len(msg_bytes))
|
||||
return varint_len + msg_bytes
|
||||
|
||||
@ -11,9 +11,7 @@ from libp2p.identity.identify.identify import (
|
||||
PROTOCOL_VERSION,
|
||||
_mk_identify_protobuf,
|
||||
_multiaddr_to_bytes,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
parse_identify_response,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
@ -29,14 +27,18 @@ async def test_identify_protocol(security_protocol):
|
||||
host_b,
|
||||
):
|
||||
# Here, host_b is the requester and host_a is the responder.
|
||||
# observed_addr represent host_b’s address as observed by host_a
|
||||
# (i.e., the address from which host_b’s request was received).
|
||||
# observed_addr represent host_b's address as observed by host_a
|
||||
# (i.e., the address from which host_b's request was received).
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read()
|
||||
|
||||
# Read the response (could be either format)
|
||||
# Read a larger chunk to get all the data before stream closes
|
||||
response = await stream.read(8192) # Read enough data in one go
|
||||
|
||||
await stream.close()
|
||||
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(response)
|
||||
# Parse the response (handles both old and new formats)
|
||||
identify_response = parse_identify_response(response)
|
||||
|
||||
logger.debug("host_a: %s", host_a.get_addrs())
|
||||
logger.debug("host_b: %s", host_b.get_addrs())
|
||||
@ -62,8 +64,9 @@ async def test_identify_protocol(security_protocol):
|
||||
|
||||
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
|
||||
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
|
||||
logger.debug("cleaned_addr= %s", cleaned_addr)
|
||||
assert identify_response.observed_addr == _multiaddr_to_bytes(cleaned_addr)
|
||||
|
||||
# The observed address should match the cleaned address
|
||||
assert Multiaddr(identify_response.observed_addr) == cleaned_addr
|
||||
|
||||
# Check protocols
|
||||
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())
|
||||
|
||||
410
tests/core/identity/identify/test_identify_parsing.py
Normal file
410
tests/core/identity/identify/test_identify_parsing.py
Normal file
@ -0,0 +1,410 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.identity.identify.identify import (
|
||||
_mk_identify_protobuf,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.io.abc import Closer, Reader, Writer
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
|
||||
|
||||
class MockStream(Reader, Writer, Closer):
|
||||
"""Mock stream for testing identify protocol compatibility."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self.data = data
|
||||
self.position = 0
|
||||
self.closed = False
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
if self.closed or self.position >= len(self.data):
|
||||
return b""
|
||||
if n is None:
|
||||
n = len(self.data) - self.position
|
||||
result = self.data[self.position : self.position + n]
|
||||
self.position += len(result)
|
||||
return result
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
# Mock write - just store the data
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def create_identify_message(host, observed_multiaddr=None):
|
||||
"""Create an identify protobuf message."""
|
||||
return _mk_identify_protobuf(host, observed_multiaddr)
|
||||
|
||||
|
||||
def create_new_format_message(identify_msg):
|
||||
"""Create a new format (length-prefixed) identify message."""
|
||||
msg_bytes = identify_msg.SerializeToString()
|
||||
return encode_varint_prefixed(msg_bytes)
|
||||
|
||||
|
||||
def create_old_format_message(identify_msg):
|
||||
"""Create an old format (raw protobuf) identify message."""
|
||||
return identify_msg.SerializeToString()
|
||||
|
||||
|
||||
async def read_new_format_message(stream) -> bytes:
|
||||
"""Read a new format (length-prefixed) identify message."""
|
||||
# Read varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
raise ValueError("No length prefix received")
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
response = await stream.read(msg_length)
|
||||
if len(response) != msg_length:
|
||||
raise ValueError("Incomplete message received")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_old_format_message(stream) -> bytes:
|
||||
"""Read an old format (raw protobuf) identify message."""
|
||||
# Read all available data
|
||||
response = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_compatible_message(stream) -> bytes:
|
||||
"""Read an identify message in either old or new format."""
|
||||
# Try to read a few bytes to detect the format
|
||||
first_bytes = await stream.read(10)
|
||||
if not first_bytes:
|
||||
raise ValueError("No data received")
|
||||
|
||||
# Try to decode as varint length prefix (new format)
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(first_bytes)
|
||||
|
||||
# Validate that the length is reasonable (not too large)
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate how many bytes the varint consumed
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(first_bytes):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Read the remaining protobuf message
|
||||
remaining_bytes = await stream.read(
|
||||
msg_length - (len(first_bytes) - varint_len)
|
||||
)
|
||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
||||
message_data = first_bytes[varint_len:] + remaining_bytes
|
||||
|
||||
# Try to parse as protobuf to validate
|
||||
try:
|
||||
Identify().ParseFromString(message_data)
|
||||
return message_data
|
||||
except Exception:
|
||||
# If protobuf parsing fails, fall back to old format
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to old format (raw protobuf)
|
||||
response = first_bytes
|
||||
|
||||
# Read more data if available
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_compatible_message_simple(stream) -> bytes:
|
||||
"""Read a message in either old or new format (simplified version for testing)."""
|
||||
# Try to read a few bytes to detect the format
|
||||
first_bytes = await stream.read(10)
|
||||
if not first_bytes:
|
||||
raise ValueError("No data received")
|
||||
|
||||
# Try to decode as varint length prefix (new format)
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(first_bytes)
|
||||
|
||||
# Validate that the length is reasonable (not too large)
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate how many bytes the varint consumed
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(first_bytes):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Read the remaining message
|
||||
remaining_bytes = await stream.read(
|
||||
msg_length - (len(first_bytes) - varint_len)
|
||||
)
|
||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
||||
return first_bytes[varint_len:] + remaining_bytes
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to old format (raw data)
|
||||
response = first_bytes
|
||||
|
||||
# Read more data if available
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def detect_format(data):
|
||||
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
|
||||
if not data:
|
||||
return "unknown"
|
||||
|
||||
# Try to decode as varint
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(data)
|
||||
|
||||
# Validate that the length is reasonable
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate varint length
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Check if we have enough data for the message
|
||||
if len(data) >= varint_len + msg_length:
|
||||
# Additional check: try to parse the message as protobuf
|
||||
try:
|
||||
message_data = data[varint_len : varint_len + msg_length]
|
||||
Identify().ParseFromString(message_data)
|
||||
return "new"
|
||||
except Exception:
|
||||
# If protobuf parsing fails, it's probably not a valid new format
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If varint decoding fails or length is unreasonable, assume old format
|
||||
return "old"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_new_format_compatibility(security_protocol):
|
||||
"""Test that identify protocol works with new format (length-prefixed) messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create new format message
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with new format data
|
||||
stream = MockStream(new_format_data)
|
||||
|
||||
# Read using new format reader
|
||||
response = await read_new_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_old_format_compatibility(security_protocol):
|
||||
"""Test that identify protocol works with old format (raw protobuf) messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create old format message
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with old format data
|
||||
stream = MockStream(old_format_data)
|
||||
|
||||
# Read using old format reader
|
||||
response = await read_old_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_backward_compatibility_old_format(security_protocol):
|
||||
"""Test backward compatibility reader with old format messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create old format message
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with old format data
|
||||
stream = MockStream(old_format_data)
|
||||
|
||||
# Read using old format reader (which should work reliably)
|
||||
response = await read_old_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_backward_compatibility_new_format(security_protocol):
|
||||
"""Test backward compatibility reader with new format messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create new format message
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with new format data
|
||||
stream = MockStream(new_format_data)
|
||||
|
||||
# Read using new format reader (which should work reliably)
|
||||
response = await read_new_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_format_detection(security_protocol):
|
||||
"""Test that the format detection works correctly."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Test new format detection
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
format_type = detect_format(new_format_data)
|
||||
assert format_type == "new", "New format should be detected correctly"
|
||||
|
||||
# Test old format detection
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
format_type = detect_format(old_format_data)
|
||||
assert format_type == "old", "Old format should be detected correctly"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_error_handling(security_protocol):
|
||||
"""Test error handling for malformed messages."""
|
||||
from libp2p.exceptions import ParseError
|
||||
|
||||
# Test with empty data
|
||||
stream = MockStream(b"")
|
||||
with pytest.raises(ValueError, match="No data received"):
|
||||
await read_compatible_message(stream)
|
||||
|
||||
# Test with incomplete varint
|
||||
stream = MockStream(b"\x80") # Incomplete varint
|
||||
with pytest.raises(ParseError, match="Unexpected end of data"):
|
||||
await read_new_format_message(stream)
|
||||
|
||||
# Test with invalid protobuf data
|
||||
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
|
||||
with pytest.raises(Exception): # Should fail when parsing protobuf
|
||||
response = await read_new_format_message(stream)
|
||||
Identify().ParseFromString(response)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_message_equivalence(security_protocol):
|
||||
"""Test that old and new format messages are equivalent."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create both formats
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Extract the protobuf message from new format
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(new_format_data):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
new_format_protobuf = new_format_data[varint_len:]
|
||||
|
||||
# The protobuf messages should be identical
|
||||
assert new_format_protobuf == old_format_data, (
|
||||
"Protobuf messages should be identical in both formats"
|
||||
)
|
||||
215
tests/core/utils/test_varint.py
Normal file
215
tests/core/utils/test_varint.py
Normal file
@ -0,0 +1,215 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.exceptions import ParseError
|
||||
from libp2p.io.abc import Reader
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
encode_uvarint,
|
||||
encode_varint_prefixed,
|
||||
read_varint_prefixed_bytes,
|
||||
)
|
||||
|
||||
|
||||
class MockReader(Reader):
|
||||
"""Mock reader for testing varint functions."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self.data = data
|
||||
self.position = 0
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
if self.position >= len(self.data):
|
||||
return b""
|
||||
if n is None:
|
||||
n = len(self.data) - self.position
|
||||
result = self.data[self.position : self.position + n]
|
||||
self.position += len(result)
|
||||
return result
|
||||
|
||||
|
||||
def test_encode_uvarint():
|
||||
"""Test varint encoding with various values."""
|
||||
test_cases = [
|
||||
(0, b"\x00"),
|
||||
(1, b"\x01"),
|
||||
(127, b"\x7f"),
|
||||
(128, b"\x80\x01"),
|
||||
(255, b"\xff\x01"),
|
||||
(256, b"\x80\x02"),
|
||||
(65535, b"\xff\xff\x03"),
|
||||
(65536, b"\x80\x80\x04"),
|
||||
(16777215, b"\xff\xff\xff\x07"),
|
||||
(16777216, b"\x80\x80\x80\x08"),
|
||||
]
|
||||
|
||||
for value, expected in test_cases:
|
||||
result = encode_uvarint(value)
|
||||
assert result == expected, (
|
||||
f"Failed for value {value}: expected {expected.hex()}, got {result.hex()}"
|
||||
)
|
||||
|
||||
|
||||
def test_decode_varint_from_bytes():
|
||||
"""Test varint decoding with various values."""
|
||||
test_cases = [
|
||||
(b"\x00", 0),
|
||||
(b"\x01", 1),
|
||||
(b"\x7f", 127),
|
||||
(b"\x80\x01", 128),
|
||||
(b"\xff\x01", 255),
|
||||
(b"\x80\x02", 256),
|
||||
(b"\xff\xff\x03", 65535),
|
||||
(b"\x80\x80\x04", 65536),
|
||||
(b"\xff\xff\xff\x07", 16777215),
|
||||
(b"\x80\x80\x80\x08", 16777216),
|
||||
]
|
||||
|
||||
for data, expected in test_cases:
|
||||
result = decode_varint_from_bytes(data)
|
||||
assert result == expected, (
|
||||
f"Failed for data {data.hex()}: expected {expected}, got {result}"
|
||||
)
|
||||
|
||||
|
||||
def test_decode_varint_from_bytes_invalid():
|
||||
"""Test varint decoding with invalid data."""
|
||||
# Empty data
|
||||
with pytest.raises(ParseError, match="Unexpected end of data"):
|
||||
decode_varint_from_bytes(b"")
|
||||
|
||||
# Incomplete varint (should not raise, but should handle gracefully)
|
||||
# This depends on the implementation - some might raise, others might return partial
|
||||
|
||||
|
||||
def test_encode_varint_prefixed():
|
||||
"""Test encoding messages with varint length prefix."""
|
||||
test_cases = [
|
||||
(b"", b"\x00"),
|
||||
(b"hello", b"\x05hello"),
|
||||
(b"x" * 127, b"\x7f" + b"x" * 127),
|
||||
(b"x" * 128, b"\x80\x01" + b"x" * 128),
|
||||
]
|
||||
|
||||
for message, expected in test_cases:
|
||||
result = encode_varint_prefixed(message)
|
||||
assert result == expected, (
|
||||
f"Failed for message {message}: expected {expected.hex()}, "
|
||||
f"got {result.hex()}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_read_varint_prefixed_bytes():
|
||||
"""Test reading length-prefixed bytes from a stream."""
|
||||
test_cases = [
|
||||
(b"", b""),
|
||||
(b"hello", b"hello"),
|
||||
(b"x" * 127, b"x" * 127),
|
||||
(b"x" * 128, b"x" * 128),
|
||||
]
|
||||
|
||||
for message, expected in test_cases:
|
||||
prefixed_data = encode_varint_prefixed(message)
|
||||
reader = MockReader(prefixed_data)
|
||||
|
||||
result = await read_varint_prefixed_bytes(reader)
|
||||
assert result == expected, (
|
||||
f"Failed for message {message}: expected {expected}, got {result}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_read_varint_prefixed_bytes_incomplete():
|
||||
"""Test reading length-prefixed bytes with incomplete data."""
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
|
||||
# Test with incomplete varint
|
||||
reader = MockReader(b"\x80") # Incomplete varint
|
||||
with pytest.raises(IncompleteReadError):
|
||||
await read_varint_prefixed_bytes(reader)
|
||||
|
||||
# Test with incomplete message
|
||||
prefixed_data = encode_varint_prefixed(b"hello world")
|
||||
reader = MockReader(prefixed_data[:-3]) # Missing last 3 bytes
|
||||
with pytest.raises(IncompleteReadError):
|
||||
await read_varint_prefixed_bytes(reader)
|
||||
|
||||
|
||||
def test_varint_roundtrip():
|
||||
"""Test roundtrip encoding and decoding."""
|
||||
test_values = [0, 1, 127, 128, 255, 256, 65535, 65536, 16777215, 16777216]
|
||||
|
||||
for value in test_values:
|
||||
encoded = encode_uvarint(value)
|
||||
decoded = decode_varint_from_bytes(encoded)
|
||||
assert decoded == value, (
|
||||
f"Roundtrip failed for {value}: encoded={encoded.hex()}, decoded={decoded}"
|
||||
)
|
||||
|
||||
|
||||
def test_varint_prefixed_roundtrip():
|
||||
"""Test roundtrip encoding and decoding of length-prefixed messages."""
|
||||
test_messages = [
|
||||
b"",
|
||||
b"hello",
|
||||
b"x" * 127,
|
||||
b"x" * 128,
|
||||
b"x" * 1000,
|
||||
]
|
||||
|
||||
for message in test_messages:
|
||||
prefixed = encode_varint_prefixed(message)
|
||||
|
||||
# Decode the length
|
||||
length = decode_varint_from_bytes(prefixed)
|
||||
assert length == len(message), (
|
||||
f"Length mismatch for {message}: expected {len(message)}, got {length}"
|
||||
)
|
||||
|
||||
# Extract the message
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(prefixed):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
extracted_message = prefixed[varint_len:]
|
||||
assert extracted_message == message, (
|
||||
f"Message mismatch: expected {message}, got {extracted_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_large_varint_values():
|
||||
"""Test varint encoding/decoding with large values."""
|
||||
large_values = [
|
||||
2**32 - 1, # 32-bit max
|
||||
2**64 - 1, # 64-bit max (if supported)
|
||||
]
|
||||
|
||||
for value in large_values:
|
||||
try:
|
||||
encoded = encode_uvarint(value)
|
||||
decoded = decode_varint_from_bytes(encoded)
|
||||
assert decoded == value, f"Large value roundtrip failed for {value}"
|
||||
except Exception as e:
|
||||
# Some implementations might not support very large values
|
||||
pytest.skip(f"Large value {value} not supported: {e}")
|
||||
|
||||
|
||||
def test_varint_edge_cases():
|
||||
"""Test varint encoding/decoding with edge cases."""
|
||||
# Test with maximum 7-bit value
|
||||
assert encode_uvarint(127) == b"\x7f"
|
||||
assert decode_varint_from_bytes(b"\x7f") == 127
|
||||
|
||||
# Test with minimum 8-bit value
|
||||
assert encode_uvarint(128) == b"\x80\x01"
|
||||
assert decode_varint_from_bytes(b"\x80\x01") == 128
|
||||
|
||||
# Test with maximum 14-bit value
|
||||
assert encode_uvarint(16383) == b"\xff\x7f"
|
||||
assert decode_varint_from_bytes(b"\xff\x7f") == 16383
|
||||
|
||||
# Test with minimum 15-bit value
|
||||
assert encode_uvarint(16384) == b"\x80\x80\x01"
|
||||
assert decode_varint_from_bytes(b"\x80\x80\x01") == 16384
|
||||
Reference in New Issue
Block a user