mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
feat: add length-prefixed protobuf support to identify protocol
This commit is contained in:
@ -8,9 +8,9 @@ import trio
|
|||||||
from libp2p import (
|
from libp2p import (
|
||||||
new_host,
|
new_host,
|
||||||
)
|
)
|
||||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
from libp2p.identity.identify.identify import (
|
||||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
ID as IDENTIFY_PROTOCOL_ID,
|
||||||
Identify,
|
parse_identify_response,
|
||||||
)
|
)
|
||||||
from libp2p.peer.peerinfo import (
|
from libp2p.peer.peerinfo import (
|
||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
@ -84,11 +84,18 @@ async def run(port: int, destination: str) -> None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
print("Starting identify protocol...")
|
print("Starting identify protocol...")
|
||||||
response = await stream.read()
|
|
||||||
|
# Read the complete response (could be either format)
|
||||||
|
# Read a larger chunk to get all the data before stream closes
|
||||||
|
response = await stream.read(8192) # Read enough data in one go
|
||||||
|
|
||||||
await stream.close()
|
await stream.close()
|
||||||
identify_msg = Identify()
|
|
||||||
identify_msg.ParseFromString(response)
|
# Parse the response using the robust protocol-level function
|
||||||
|
# This handles both old and new formats automatically
|
||||||
|
identify_msg = parse_identify_response(response)
|
||||||
print_identify_response(identify_msg)
|
print_identify_response(identify_msg)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Identify protocol error: {e}")
|
print(f"Identify protocol error: {e}")
|
||||||
|
|
||||||
|
|||||||
@ -26,5 +26,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
|
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
((IdentifyID, identify_handler_for(host)), (PingID, handle_ping))
|
(
|
||||||
|
(IdentifyID, identify_handler_for(host, use_varint_format=False)),
|
||||||
|
(PingID, handle_ping),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,7 +16,9 @@ from libp2p.network.stream.exceptions import (
|
|||||||
StreamClosed,
|
StreamClosed,
|
||||||
)
|
)
|
||||||
from libp2p.utils import (
|
from libp2p.utils import (
|
||||||
|
decode_varint_with_size,
|
||||||
get_agent_version,
|
get_agent_version,
|
||||||
|
varint,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .pb.identify_pb2 import (
|
from .pb.identify_pb2 import (
|
||||||
@ -72,7 +74,47 @@ def _mk_identify_protobuf(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
def parse_identify_response(response: bytes) -> Identify:
    """
    Decode an identify response sent in either supported wire layout.

    The length-prefixed (new) layout is attempted first: a leading varint
    must describe a non-empty payload that fits inside ``response`` and the
    decoded message must carry an ``agent_version``. If any of that fails,
    the whole buffer is parsed as a raw protobuf (old layout).

    :param response: bytes read from the identify stream
    :return: the decoded ``Identify`` message
    :raises Exception: whatever the raw protobuf parse raised, after the
        failure details have been logged
    """
    # Attempt 1: varint length prefix followed by the protobuf payload.
    if response:
        declared_len, prefix_len = decode_varint_with_size(response)
        payload_end = prefix_len + declared_len
        prefix_usable = prefix_len > 0 and declared_len > 0
        if prefix_usable and payload_end <= len(response):
            payload = response[prefix_len:payload_end]
            try:
                candidate = Identify()
                candidate.ParseFromString(payload)
                # Sanity check: a real identify message names its agent
                # (protocol_version is optional).
                accepted = bool(candidate.agent_version)
            except Exception:
                # Not a usable length-prefixed message; try the raw layout.
                accepted = False
            if accepted:
                logger.debug("Parsed length-prefixed identify response (new format)")
                return candidate

    # Attempt 2: the buffer is the bare protobuf message.
    try:
        legacy = Identify()
        legacy.ParseFromString(response)
        logger.debug("Parsed raw protobuf identify response (old format)")
        return legacy
    except Exception as exc:
        logger.error(f"Failed to parse identify response: {exc}")
        logger.error(f"Response length: {len(response)}")
        logger.error(f"Response hex: {response.hex()}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def identify_handler_for(
|
||||||
|
host: IHost, use_varint_format: bool = False
|
||||||
|
) -> StreamHandlerFn:
|
||||||
async def handle_identify(stream: INetStream) -> None:
|
async def handle_identify(stream: INetStream) -> None:
|
||||||
# get observed address from ``stream``
|
# get observed address from ``stream``
|
||||||
peer_id = (
|
peer_id = (
|
||||||
@ -100,7 +142,21 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
|||||||
response = protobuf.SerializeToString()
|
response = protobuf.SerializeToString()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await stream.write(response)
|
if use_varint_format:
|
||||||
|
# Send length-prefixed protobuf message (new format)
|
||||||
|
await stream.write(varint.encode_uvarint(len(response)))
|
||||||
|
await stream.write(response)
|
||||||
|
logger.debug(
|
||||||
|
"Sent new format (length-prefixed) identify response to %s",
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Send raw protobuf message (old format for backward compatibility)
|
||||||
|
await stream.write(response)
|
||||||
|
logger.debug(
|
||||||
|
"Sent old format (raw protobuf) identify response to %s",
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
except StreamClosed:
|
except StreamClosed:
|
||||||
logger.debug("Fail to respond to %s request: stream closed", ID)
|
logger.debug("Fail to respond to %s request: stream closed", ID)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -25,6 +25,10 @@ from libp2p.peer.id import (
|
|||||||
)
|
)
|
||||||
from libp2p.utils import (
|
from libp2p.utils import (
|
||||||
get_agent_version,
|
get_agent_version,
|
||||||
|
varint,
|
||||||
|
)
|
||||||
|
from libp2p.utils.varint import (
|
||||||
|
decode_varint_from_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..identify.identify import (
|
from ..identify.identify import (
|
||||||
@ -55,8 +59,29 @@ def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
|
|||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Read the identify message from the stream
|
# Read length-prefixed identify message from the stream
|
||||||
data = await stream.read()
|
# First read the varint length prefix
|
||||||
|
length_bytes = b""
|
||||||
|
while True:
|
||||||
|
b = await stream.read(1)
|
||||||
|
if not b:
|
||||||
|
break
|
||||||
|
length_bytes += b
|
||||||
|
if b[0] & 0x80 == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not length_bytes:
|
||||||
|
logger.warning("No length prefix received from peer %s", peer_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg_length = decode_varint_from_bytes(length_bytes)
|
||||||
|
|
||||||
|
# Read the protobuf message
|
||||||
|
data = await stream.read(msg_length)
|
||||||
|
if len(data) != msg_length:
|
||||||
|
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||||
|
return
|
||||||
|
|
||||||
identify_msg = Identify()
|
identify_msg = Identify()
|
||||||
identify_msg.ParseFromString(data)
|
identify_msg.ParseFromString(data)
|
||||||
|
|
||||||
@ -159,7 +184,8 @@ async def push_identify_to_peer(
|
|||||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||||
response = identify_msg.SerializeToString()
|
response = identify_msg.SerializeToString()
|
||||||
|
|
||||||
# Send the identify message
|
# Send length-prefixed identify message
|
||||||
|
await stream.write(varint.encode_uvarint(len(response)))
|
||||||
await stream.write(response)
|
await stream.write(response)
|
||||||
|
|
||||||
# Close the stream
|
# Close the stream
|
||||||
|
|||||||
@ -7,6 +7,8 @@ from libp2p.utils.varint import (
|
|||||||
encode_varint_prefixed,
|
encode_varint_prefixed,
|
||||||
read_delim,
|
read_delim,
|
||||||
read_varint_prefixed_bytes,
|
read_varint_prefixed_bytes,
|
||||||
|
decode_varint_from_bytes,
|
||||||
|
decode_varint_with_size,
|
||||||
)
|
)
|
||||||
from libp2p.utils.version import (
|
from libp2p.utils.version import (
|
||||||
get_agent_version,
|
get_agent_version,
|
||||||
@ -20,4 +22,6 @@ __all__ = [
|
|||||||
"get_agent_version",
|
"get_agent_version",
|
||||||
"read_delim",
|
"read_delim",
|
||||||
"read_varint_prefixed_bytes",
|
"read_varint_prefixed_bytes",
|
||||||
|
"decode_varint_from_bytes",
|
||||||
|
"decode_varint_with_size",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -39,6 +39,30 @@ def encode_uvarint(number: int) -> bytes:
|
|||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
|
||||||
|
def decode_varint_from_bytes(data: bytes) -> int:
    """
    Decode the varint at the start of ``data`` and return its value.

    Synchronous counterpart of ``decode_uvarint_from_stream`` for bytes that
    have already been read. Bytes following the terminating byte are ignored.

    :raises ParseError: if the shift exceeds ``SHIFT_64_BIT_MAX`` or the
        data runs out before a byte without the continuation bit is seen
    """
    total = 0
    shift = 0
    pending = data
    while True:
        if shift > SHIFT_64_BIT_MAX:
            raise ParseError("Integer is too large...")
        if not pending:
            raise ParseError("Unexpected end of data")

        # Consume one byte from the front of the remaining input.
        current, pending = pending[0], pending[1:]
        total += (current & LOW_MASK) << shift

        # A clear high bit marks the last byte of the varint.
        if not current & HIGH_MASK:
            return total
        shift += 7
|
||||||
|
|
||||||
|
|
||||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||||
"""https://en.wikipedia.org/wiki/LEB128."""
|
"""https://en.wikipedia.org/wiki/LEB128."""
|
||||||
res = 0
|
res = 0
|
||||||
@ -56,6 +80,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode a leading varint and report how many bytes it occupied.

    :param data: buffer expected to start with a varint
    :return: ``(value, bytes_consumed)``, or ``(0, 0)`` when ``data`` is
        empty or does not begin with a decodable varint
    """
    try:
        # Count bytes up to and including the first one whose
        # continuation bit (0x80) is clear.
        consumed = 0
        for octet in data:
            consumed += 1
            if not octet & 0x80:
                break

        if not consumed:
            return 0, 0

        # Decode only the prefix that makes up the varint.
        return decode_varint_from_bytes(data[:consumed]), consumed
    except Exception:
        return 0, 0
|
||||||
|
|
||||||
|
|
||||||
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
    """Return ``msg_bytes`` preceded by its length encoded as an unsigned varint."""
    return encode_uvarint(len(msg_bytes)) + msg_bytes
|
||||||
|
|||||||
@ -11,9 +11,7 @@ from libp2p.identity.identify.identify import (
|
|||||||
PROTOCOL_VERSION,
|
PROTOCOL_VERSION,
|
||||||
_mk_identify_protobuf,
|
_mk_identify_protobuf,
|
||||||
_multiaddr_to_bytes,
|
_multiaddr_to_bytes,
|
||||||
)
|
parse_identify_response,
|
||||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
|
||||||
Identify,
|
|
||||||
)
|
)
|
||||||
from tests.utils.factories import (
|
from tests.utils.factories import (
|
||||||
host_pair_factory,
|
host_pair_factory,
|
||||||
@ -29,14 +27,18 @@ async def test_identify_protocol(security_protocol):
|
|||||||
host_b,
|
host_b,
|
||||||
):
|
):
|
||||||
# Here, host_b is the requester and host_a is the responder.
|
# Here, host_b is the requester and host_a is the responder.
|
||||||
# observed_addr represent host_b’s address as observed by host_a
|
# observed_addr represent host_b's address as observed by host_a
|
||||||
# (i.e., the address from which host_b’s request was received).
|
# (i.e., the address from which host_b's request was received).
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
response = await stream.read()
|
|
||||||
|
# Read the response (could be either format)
|
||||||
|
# Read a larger chunk to get all the data before stream closes
|
||||||
|
response = await stream.read(8192) # Read enough data in one go
|
||||||
|
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
identify_response = Identify()
|
# Parse the response (handles both old and new formats)
|
||||||
identify_response.ParseFromString(response)
|
identify_response = parse_identify_response(response)
|
||||||
|
|
||||||
logger.debug("host_a: %s", host_a.get_addrs())
|
logger.debug("host_a: %s", host_a.get_addrs())
|
||||||
logger.debug("host_b: %s", host_b.get_addrs())
|
logger.debug("host_b: %s", host_b.get_addrs())
|
||||||
@ -62,8 +64,9 @@ async def test_identify_protocol(security_protocol):
|
|||||||
|
|
||||||
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
|
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
|
||||||
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
|
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
|
||||||
logger.debug("cleaned_addr= %s", cleaned_addr)
|
|
||||||
assert identify_response.observed_addr == _multiaddr_to_bytes(cleaned_addr)
|
# The observed address should match the cleaned address
|
||||||
|
assert Multiaddr(identify_response.observed_addr) == cleaned_addr
|
||||||
|
|
||||||
# Check protocols
|
# Check protocols
|
||||||
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())
|
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())
|
||||||
|
|||||||
410
tests/core/identity/identify/test_identify_parsing.py
Normal file
410
tests/core/identity/identify/test_identify_parsing.py
Normal file
@ -0,0 +1,410 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.identity.identify.identify import (
|
||||||
|
_mk_identify_protobuf,
|
||||||
|
)
|
||||||
|
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||||
|
Identify,
|
||||||
|
)
|
||||||
|
from libp2p.io.abc import Closer, Reader, Writer
|
||||||
|
from libp2p.utils.varint import (
|
||||||
|
decode_varint_from_bytes,
|
||||||
|
encode_varint_prefixed,
|
||||||
|
)
|
||||||
|
from tests.utils.factories import (
|
||||||
|
host_pair_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockStream(Reader, Writer, Closer):
    """In-memory stream stub used by the identify compatibility tests."""

    def __init__(self, data: bytes):
        # Bytes served by ``read`` and the offset of the next unread byte.
        self.data = data
        self.position = 0
        self.closed = False

    async def read(self, n: int | None = None) -> bytes:
        """Return up to ``n`` unread bytes (all of them when ``n`` is None)."""
        exhausted = self.position >= len(self.data)
        if self.closed or exhausted:
            return b""
        start = self.position
        stop = len(self.data) if n is None else start + n
        chunk = self.data[start:stop]
        self.position = start + len(chunk)
        return chunk

    async def write(self, data: bytes) -> None:
        """Accept a write; the data is discarded."""
        return None

    async def close(self) -> None:
        """Mark the stream closed so later reads return no data."""
        self.closed = True
|
||||||
|
|
||||||
|
|
||||||
|
def create_identify_message(host, observed_multiaddr=None):
    """Build the identify protobuf for ``host`` via ``_mk_identify_protobuf``."""
    identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
    return identify_msg
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_format_message(identify_msg):
    """Serialize ``identify_msg`` in the new format: varint length, then protobuf."""
    return encode_varint_prefixed(identify_msg.SerializeToString())
|
||||||
|
|
||||||
|
|
||||||
|
def create_old_format_message(identify_msg):
    """Serialize ``identify_msg`` in the old format: the bare protobuf bytes."""
    raw_protobuf = identify_msg.SerializeToString()
    return raw_protobuf
|
||||||
|
|
||||||
|
|
||||||
|
async def read_new_format_message(stream) -> bytes:
    """
    Read one new-format (length-prefixed) identify message from ``stream``.

    :param stream: object exposing ``async read(n)``
    :return: the protobuf payload, without its length prefix
    :raises ValueError: if the stream yields no prefix byte at all, or ends
        before the announced number of payload bytes has arrived
    """
    # Collect the varint prefix one byte at a time; the byte with a clear
    # high bit is the last one.
    prefix = bytearray()
    while True:
        byte = await stream.read(1)
        if not byte:
            break
        prefix += byte
        if byte[0] & 0x80 == 0:
            break

    if not prefix:
        raise ValueError("No length prefix received")

    msg_length = decode_varint_from_bytes(bytes(prefix))

    # A read may return fewer bytes than requested, so keep reading until
    # the payload is complete or the stream is exhausted.
    chunks = []
    outstanding = msg_length
    while outstanding > 0:
        chunk = await stream.read(outstanding)
        if not chunk:
            break
        chunks.append(chunk)
        outstanding -= len(chunk)

    if outstanding != 0:
        raise ValueError("Incomplete message received")

    return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_old_format_message(stream) -> bytes:
    """
    Read an old-format (raw protobuf) identify message.

    Pulls 4096-byte reads from ``stream`` until a read returns nothing and
    returns everything received, concatenated.
    """
    pieces = []
    while chunk := await stream.read(4096):
        pieces.append(chunk)
    return b"".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_compatible_message(stream) -> bytes:
    """
    Read an identify message in either old or new format.

    The first bytes are probed as a varint length prefix. If the prefix is
    plausible, the announced payload can be read in full and it parses as
    an ``Identify`` protobuf, that payload is returned. Otherwise every
    byte obtained from the stream is returned as a raw (old format) message.

    :param stream: object exposing ``async read(n)``
    :return: protobuf bytes (length prefix stripped for the new format)
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything taken off the stream so far. The fallback path must start
    # from this, otherwise bytes read during the probe would be lost.
    consumed = first_bytes

    # Try to decode as varint length prefix (new format)
    try:
        msg_length = decode_varint_from_bytes(first_bytes)

        # Validate that the length is reasonable (not too large)
        if 0 < msg_length <= 1024 * 1024:  # Max 1MB
            # Calculate how many bytes the varint consumed
            varint_len = 0
            for byte in first_bytes:
                varint_len += 1
                if (byte & 0x80) == 0:
                    break

            # Payload bytes still missing after the probe read. A negative
            # value means the probe already holds more than the announced
            # payload, so this is not a single length-prefixed message.
            needed = msg_length - (len(first_bytes) - varint_len)
            if needed >= 0:
                remaining_bytes = await stream.read(needed) if needed else b""
                consumed += remaining_bytes
                if len(remaining_bytes) == needed:
                    message_data = first_bytes[varint_len:] + remaining_bytes

                    # Try to parse as protobuf to validate
                    try:
                        Identify().ParseFromString(message_data)
                        return message_data
                    except Exception:
                        # If protobuf parsing fails, fall back to old format
                        pass
    except Exception:
        pass

    # Fall back to old format (raw protobuf), keeping all bytes read so far
    pieces = [consumed]
    while True:
        chunk = await stream.read(4096)
        if not chunk:
            break
        pieces.append(chunk)

    return b"".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_compatible_message_simple(stream) -> bytes:
    """
    Read a message in either old or new format (simplified version for testing).

    Works like a length-prefix probe followed by a raw fallback, but does
    not validate the payload as protobuf: if the first bytes decode to a
    plausible length and that many payload bytes can be read, the payload
    is returned. Otherwise every byte obtained from the stream is returned.

    :param stream: object exposing ``async read(n)``
    :raises ValueError: if the stream yields no data at all
    """
    # Try to read a few bytes to detect the format
    first_bytes = await stream.read(10)
    if not first_bytes:
        raise ValueError("No data received")

    # Everything taken off the stream so far. The fallback path must start
    # from this, otherwise bytes read during the probe would be lost.
    consumed = first_bytes

    # Try to decode as varint length prefix (new format)
    try:
        msg_length = decode_varint_from_bytes(first_bytes)

        # Validate that the length is reasonable (not too large)
        if 0 < msg_length <= 1024 * 1024:  # Max 1MB
            # Calculate how many bytes the varint consumed
            varint_len = 0
            for byte in first_bytes:
                varint_len += 1
                if (byte & 0x80) == 0:
                    break

            # Payload bytes still missing after the probe read; never ask
            # the stream for a negative number of bytes.
            needed = msg_length - (len(first_bytes) - varint_len)
            if needed >= 0:
                remaining_bytes = await stream.read(needed) if needed else b""
                consumed += remaining_bytes
                if len(remaining_bytes) == needed:
                    return first_bytes[varint_len:] + remaining_bytes
    except Exception:
        pass

    # Fall back to old format (raw data), keeping all bytes read so far
    pieces = [consumed]
    while True:
        chunk = await stream.read(4096)
        if not chunk:
            break
        pieces.append(chunk)

    return b"".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_format(data):
    """
    Classify ``data`` as a new (varint-prefixed) or old (raw) identify message.

    :return: ``"unknown"`` for empty input; ``"new"`` when a leading varint
        announces 1 byte to 1 MiB, that many bytes are present and they
        parse as ``Identify``; ``"old"`` in every other case
    """
    if not data:
        return "unknown"

    size_limit = 1024 * 1024  # Max 1MB
    try:
        declared = decode_varint_from_bytes(data)
        if 0 < declared <= size_limit:
            # Length of the prefix: bytes up to and including the first
            # one whose continuation bit is clear.
            prefix_len = next(
                (idx + 1 for idx, octet in enumerate(data) if not octet & 0x80),
                len(data),
            )
            if prefix_len + declared <= len(data):
                # Final check: the announced payload must be valid protobuf.
                try:
                    payload = data[prefix_len : prefix_len + declared]
                    Identify().ParseFromString(payload)
                    return "new"
                except Exception:
                    # Unparseable payload: not a valid new-format message.
                    pass
    except Exception:
        pass

    # Varint decoding failed or nothing matched: assume old format.
    return "old"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
    """A length-prefixed identify message survives a write/read round trip."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        expected = create_identify_message(host_a)

        # Serve the length-prefixed bytes from an in-memory stream and read
        # them back with the new-format reader.
        wire_bytes = create_new_format_message(expected)
        payload = await read_new_format_message(MockStream(wire_bytes))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the one that was sent.
        for field in ("protocol_version", "agent_version", "public_key"):
            assert getattr(decoded, field) == getattr(expected, field)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
    """A raw-protobuf identify message survives a write/read round trip."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        expected = create_identify_message(host_a)

        # Serve the raw protobuf bytes from an in-memory stream and read
        # them back with the old-format reader.
        wire_bytes = create_old_format_message(expected)
        payload = await read_old_format_message(MockStream(wire_bytes))

        decoded = Identify()
        decoded.ParseFromString(payload)

        # The decoded message must match the one that was sent.
        for field in ("protocol_version", "agent_version", "public_key"):
            assert getattr(decoded, field) == getattr(expected, field)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
    """
    Old-format bytes decode to the original message.

    Note: the data is read with ``read_old_format_message``; the
    format-detecting reader is not exercised here.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        sent = create_identify_message(host_a)
        stream = MockStream(create_old_format_message(sent))

        # Read with the dedicated old-format reader.
        raw = await read_old_format_message(stream)

        received = Identify()
        received.ParseFromString(raw)

        assert received.protocol_version == sent.protocol_version
        assert received.agent_version == sent.agent_version
        assert received.public_key == sent.public_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
    """
    New-format bytes decode to the original message.

    Note: the data is read with ``read_new_format_message``; the
    format-detecting reader is not exercised here.
    """
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        sent = create_identify_message(host_a)
        stream = MockStream(create_new_format_message(sent))

        # Read with the dedicated new-format reader.
        raw = await read_new_format_message(stream)

        received = Identify()
        received.ParseFromString(raw)

        assert received.protocol_version == sent.protocol_version
        assert received.agent_version == sent.agent_version
        assert received.public_key == sent.public_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
    """``detect_format`` labels each serialization of one message correctly."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)

        # (wire bytes, expected label, failure message)
        cases = (
            (
                create_new_format_message(identify_msg),
                "new",
                "New format should be detected correctly",
            ),
            (
                create_old_format_message(identify_msg),
                "old",
                "Old format should be detected correctly",
            ),
        )
        for wire_bytes, expected_label, failure_message in cases:
            assert detect_format(wire_bytes) == expected_label, failure_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
    """Malformed input makes the readers fail with the expected errors."""
    from libp2p.exceptions import ParseError

    # Nothing on the stream at all.
    with pytest.raises(ValueError, match="No data received"):
        await read_compatible_message(MockStream(b""))

    # A lone continuation byte: the varint never terminates.
    truncated_prefix = MockStream(b"\x80")
    with pytest.raises(ParseError, match="Unexpected end of data"):
        await read_new_format_message(truncated_prefix)

    # Well-formed length prefix followed by bytes that are not protobuf.
    garbage_payload = MockStream(b"\x05invalid")
    with pytest.raises(Exception):  # Should fail when parsing protobuf
        payload = await read_new_format_message(garbage_payload)
        Identify().ParseFromString(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
    """Both formats carry byte-identical protobuf payloads."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        identify_msg = create_identify_message(host_a)

        prefixed = create_new_format_message(identify_msg)
        raw = create_old_format_message(identify_msg)

        # The prefix ends at the first byte whose continuation bit is clear.
        prefix_len = next(
            (idx + 1 for idx, octet in enumerate(prefixed) if not octet & 0x80),
            len(prefixed),
        )

        # Stripping the prefix must leave exactly the raw serialization.
        assert prefixed[prefix_len:] == raw, (
            "Protobuf messages should be identical in both formats"
        )
|
||||||
215
tests/core/utils/test_varint.py
Normal file
215
tests/core/utils/test_varint.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.exceptions import ParseError
|
||||||
|
from libp2p.io.abc import Reader
|
||||||
|
from libp2p.utils.varint import (
|
||||||
|
decode_varint_from_bytes,
|
||||||
|
encode_uvarint,
|
||||||
|
encode_varint_prefixed,
|
||||||
|
read_varint_prefixed_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockReader(Reader):
    """In-memory reader stub for exercising the varint helpers."""

    def __init__(self, data: bytes):
        # Bytes served by ``read`` and the offset of the next unread byte.
        self.data = data
        self.position = 0

    async def read(self, n: int | None = None) -> bytes:
        """Return up to ``n`` unread bytes (all of them when ``n`` is None)."""
        start = self.position
        if start >= len(self.data):
            return b""
        stop = len(self.data) if n is None else start + n
        chunk = self.data[start:stop]
        self.position = start + len(chunk)
        return chunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_uvarint():
    """``encode_uvarint`` produces the expected bytes at size boundaries."""
    # (value, expected encoding) pairs around the 7/14/21/28-bit boundaries.
    vectors = (
        (0, b"\x00"),
        (1, b"\x01"),
        (127, b"\x7f"),
        (128, b"\x80\x01"),
        (255, b"\xff\x01"),
        (256, b"\x80\x02"),
        (65535, b"\xff\xff\x03"),
        (65536, b"\x80\x80\x04"),
        (16777215, b"\xff\xff\xff\x07"),
        (16777216, b"\x80\x80\x80\x08"),
    )

    for value, expected in vectors:
        encoded = encode_uvarint(value)
        assert encoded == expected, (
            f"Failed for value {value}: expected {expected.hex()}, got {encoded.hex()}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_varint_from_bytes():
    """Check ``decode_varint_from_bytes`` against known encodings of one to four bytes."""
    # Maps each encoded input to the integer it should decode to.
    expected_by_encoding = {
        b"\x00": 0,
        b"\x01": 1,
        b"\x7f": 127,
        b"\x80\x01": 128,
        b"\xff\x01": 255,
        b"\x80\x02": 256,
        b"\xff\xff\x03": 65535,
        b"\x80\x80\x04": 65536,
        b"\xff\xff\xff\x07": 16777215,
        b"\x80\x80\x80\x08": 16777216,
    }

    for data, expected in expected_by_encoding.items():
        result = decode_varint_from_bytes(data)
        assert result == expected, (
            f"Failed for data {data.hex()}: expected {expected}, got {result}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_varint_from_bytes_invalid():
    """Check that decoding an empty buffer raises ``ParseError``."""
    no_data = b""
    with pytest.raises(ParseError, match="Unexpected end of data"):
        decode_varint_from_bytes(no_data)

    # NOTE: an incomplete varint (continuation bit set on the final byte) is
    # deliberately not asserted here; whether the decoder raises or returns a
    # partial value depends on the implementation.
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_varint_prefixed():
    """Check that ``encode_varint_prefixed`` prepends the payload length as a varint."""
    # 127 is the last length whose prefix is one byte; 128 needs two.
    filler_127 = b"x" * 127
    filler_128 = b"x" * 128
    cases = (
        (b"", b"\x00"),
        (b"hello", b"\x05hello"),
        (filler_127, b"\x7f" + filler_127),
        (filler_128, b"\x80\x01" + filler_128),
    )

    for message, expected in cases:
        result = encode_varint_prefixed(message)
        assert result == expected, (
            f"Failed for message {message}: expected {expected.hex()}, "
            f"got {result.hex()}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_read_varint_prefixed_bytes():
    """Check that a length-prefixed payload read back from a stream is unchanged."""
    messages = [
        b"",
        b"hello",
        b"x" * 127,
        b"x" * 128,
    ]

    for message in messages:
        # Reading must return exactly the payload that was framed.
        expected = message
        reader = MockReader(encode_varint_prefixed(message))

        result = await read_varint_prefixed_bytes(reader)
        assert result == expected, (
            f"Failed for message {message}: expected {expected}, got {result}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
async def test_read_varint_prefixed_bytes_incomplete():
    """Check that truncated input raises ``IncompleteReadError``."""
    from libp2p.io.exceptions import IncompleteReadError

    # Case 1: the stream ends inside the length prefix; the only byte has its
    # continuation bit set, so another byte is expected.
    truncated_prefix = MockReader(b"\x80")
    with pytest.raises(IncompleteReadError):
        await read_varint_prefixed_bytes(truncated_prefix)

    # Case 2: the prefix is complete but the payload is missing its last 3 bytes.
    framed = encode_varint_prefixed(b"hello world")
    truncated_payload = MockReader(framed[: len(framed) - 3])
    with pytest.raises(IncompleteReadError):
        await read_varint_prefixed_bytes(truncated_payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_varint_roundtrip():
    """Check that decoding an encoded value returns the original integer."""
    samples = (0, 1, 127, 128, 255, 256, 65535, 65536, 16777215, 16777216)

    for value in samples:
        encoded = encode_uvarint(value)
        decoded = decode_varint_from_bytes(encoded)
        assert decoded == value, (
            f"Roundtrip failed for {value}: encoded={encoded.hex()}, decoded={decoded}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_varint_prefixed_roundtrip():
    """
    Test roundtrip encoding and decoding of length-prefixed messages.

    For each payload, the frame produced by ``encode_varint_prefixed`` must
    start with a varint equal to the payload length, and the bytes following
    that varint must be the payload itself.
    """
    test_messages = [
        b"",
        b"hello",
        b"x" * 127,
        b"x" * 128,
        b"x" * 1000,
    ]

    for message in test_messages:
        prefixed = encode_varint_prefixed(message)

        # Decode the length
        length = decode_varint_from_bytes(prefixed)
        assert length == len(message), (
            f"Length mismatch for {message}: expected {len(message)}, got {length}"
        )

        # Measure the prefix independently of the decoder: count bytes up to and
        # including the first one whose continuation bit (0x80) is clear.
        varint_len = 0
        for byte in prefixed:
            varint_len += 1
            if not byte & 0x80:
                break

        # Extract the message
        extracted_message = prefixed[varint_len:]
        assert extracted_message == message, (
            f"Message mismatch: expected {message}, got {extracted_message}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_large_varint_values():
    """
    Test varint encoding/decoding with large values.

    If the implementation raises while encoding or decoding a large value, the
    test is skipped, since very large values may not be supported. A value that
    roundtrips to a different number is a real failure and must not be skipped.
    """
    large_values = [
        2**32 - 1,  # 32-bit max
        2**64 - 1,  # 64-bit max (if supported)
    ]

    for value in large_values:
        try:
            encoded = encode_uvarint(value)
            decoded = decode_varint_from_bytes(encoded)
        except Exception as e:
            # Some implementations might not support very large values
            pytest.skip(f"Large value {value} not supported: {e}")

        # Kept outside the ``try``: AssertionError is an Exception subclass, so
        # asserting inside it would turn a wrong roundtrip into a skip.
        assert decoded == value, f"Large value roundtrip failed for {value}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_varint_edge_cases():
    """Check encode and decode at the boundaries where the encoding grows by a byte."""
    boundaries = (
        (127, b"\x7f"),  # maximum 7-bit value
        (128, b"\x80\x01"),  # minimum 8-bit value
        (16383, b"\xff\x7f"),  # maximum 14-bit value
        (16384, b"\x80\x80\x01"),  # minimum 15-bit value
    )

    for number, wire in boundaries:
        assert encode_uvarint(number) == wire
        assert decode_varint_from_bytes(wire) == number
|
||||||
Reference in New Issue
Block a user