mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-08 14:10:54 +00:00
feat: add length-prefixed protobuf support to identify protocol
This commit is contained in:
@ -26,5 +26,8 @@ if TYPE_CHECKING:
|
||||
|
||||
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
|
||||
return OrderedDict(
|
||||
((IdentifyID, identify_handler_for(host)), (PingID, handle_ping))
|
||||
(
|
||||
(IdentifyID, identify_handler_for(host, use_varint_format=False)),
|
||||
(PingID, handle_ping),
|
||||
)
|
||||
)
|
||||
|
||||
@ -16,7 +16,9 @@ from libp2p.network.stream.exceptions import (
|
||||
StreamClosed,
|
||||
)
|
||||
from libp2p.utils import (
|
||||
decode_varint_with_size,
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
|
||||
from .pb.identify_pb2 import (
|
||||
@ -72,7 +74,47 @@ def _mk_identify_protobuf(
|
||||
)
|
||||
|
||||
|
||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
def parse_identify_response(response: bytes) -> Identify:
|
||||
"""
|
||||
Parse identify response that could be either:
|
||||
- Old format: raw protobuf
|
||||
- New format: length-prefixed protobuf
|
||||
|
||||
This function provides backward and forward compatibility.
|
||||
"""
|
||||
# Try new format first: length-prefixed protobuf
|
||||
if len(response) >= 1:
|
||||
length, varint_size = decode_varint_with_size(response)
|
||||
if varint_size > 0 and length > 0 and varint_size + length <= len(response):
|
||||
protobuf_data = response[varint_size : varint_size + length]
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(protobuf_data)
|
||||
# Sanity check: must have agent_version (protocol_version is optional)
|
||||
if identify_response.agent_version:
|
||||
logger.debug(
|
||||
"Parsed length-prefixed identify response (new format)"
|
||||
)
|
||||
return identify_response
|
||||
except Exception:
|
||||
pass # Fall through to old format
|
||||
|
||||
# Fall back to old format: raw protobuf
|
||||
try:
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(response)
|
||||
logger.debug("Parsed raw protobuf identify response (old format)")
|
||||
return identify_response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse identify response: {e}")
|
||||
logger.error(f"Response length: {len(response)}")
|
||||
logger.error(f"Response hex: {response.hex()}")
|
||||
raise
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
peer_id = (
|
||||
@ -100,7 +142,21 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
response = protobuf.SerializeToString()
|
||||
|
||||
try:
|
||||
await stream.write(response)
|
||||
if use_varint_format:
|
||||
# Send length-prefixed protobuf message (new format)
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent new format (length-prefixed) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
else:
|
||||
# Send raw protobuf message (old format for backward compatibility)
|
||||
await stream.write(response)
|
||||
logger.debug(
|
||||
"Sent old format (raw protobuf) identify response to %s",
|
||||
peer_id,
|
||||
)
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to respond to %s request: stream closed", ID)
|
||||
else:
|
||||
|
||||
@ -25,6 +25,10 @@ from libp2p.peer.id import (
|
||||
)
|
||||
from libp2p.utils import (
|
||||
get_agent_version,
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -55,8 +59,29 @@ def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
# Read the identify message from the stream
|
||||
data = await stream.read()
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
@ -159,7 +184,8 @@ async def push_identify_to_peer(
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response = identify_msg.SerializeToString()
|
||||
|
||||
# Send the identify message
|
||||
# Send length-prefixed identify message
|
||||
await stream.write(varint.encode_uvarint(len(response)))
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
|
||||
@ -7,6 +7,8 @@ from libp2p.utils.varint import (
|
||||
encode_varint_prefixed,
|
||||
read_delim,
|
||||
read_varint_prefixed_bytes,
|
||||
decode_varint_from_bytes,
|
||||
decode_varint_with_size,
|
||||
)
|
||||
from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
@ -20,4 +22,6 @@ __all__ = [
|
||||
"get_agent_version",
|
||||
"read_delim",
|
||||
"read_varint_prefixed_bytes",
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
]
|
||||
|
||||
@ -39,6 +39,30 @@ def encode_uvarint(number: int) -> bytes:
|
||||
return buf
|
||||
|
||||
|
||||
def decode_varint_from_bytes(data: bytes) -> int:
|
||||
"""
|
||||
Decode a varint from bytes and return the value.
|
||||
|
||||
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
|
||||
"""
|
||||
res = 0
|
||||
for shift in itertools.count(0, 7):
|
||||
if shift > SHIFT_64_BIT_MAX:
|
||||
raise ParseError("Integer is too large...")
|
||||
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
value = data[0]
|
||||
data = data[1:]
|
||||
|
||||
res += (value & LOW_MASK) << shift
|
||||
|
||||
if not value & HIGH_MASK:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
"""https://en.wikipedia.org/wiki/LEB128."""
|
||||
res = 0
|
||||
@ -56,6 +80,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
return res
|
||||
|
||||
|
||||
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
|
||||
"""
|
||||
Decode a varint from bytes and return (value, bytes_consumed).
|
||||
Returns (0, 0) if the data doesn't start with a valid varint.
|
||||
"""
|
||||
try:
|
||||
# Calculate how many bytes the varint consumes
|
||||
varint_size = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_size += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
if varint_size == 0:
|
||||
return 0, 0
|
||||
|
||||
# Extract just the varint bytes
|
||||
varint_bytes = data[:varint_size]
|
||||
|
||||
# Decode the varint
|
||||
value = decode_varint_from_bytes(varint_bytes)
|
||||
|
||||
return value, varint_size
|
||||
except Exception:
|
||||
return 0, 0
|
||||
|
||||
|
||||
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||
varint_len = encode_uvarint(len(msg_bytes))
|
||||
return varint_len + msg_bytes
|
||||
|
||||
Reference in New Issue
Block a user