mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
- Add comprehensive integration tests for identify-push protocol - Support both raw protobuf and varint message formats - Improve yamux logging integration with LIBP2P_DEBUG - Fix RawConnError handling to reduce log noise - Add Ctrl+C handling to identify examples - Enhance identify-push listener/dialer demo Fixes: #784
233 lines
6.2 KiB
Python
233 lines
6.2 KiB
Python
import itertools
|
|
import logging
|
|
import math
|
|
from typing import BinaryIO
|
|
|
|
from libp2p.abc import INetStream
|
|
from libp2p.exceptions import (
|
|
ParseError,
|
|
)
|
|
from libp2p.io.abc import (
|
|
Reader,
|
|
)
|
|
from libp2p.io.utils import (
|
|
read_exactly,
|
|
)
|
|
|
|
logger = logging.getLogger("libp2p.utils.varint")

# Unsigned LEB128 (varint) codec.
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py

# Selects the 7 payload bits of a varint byte (0b0111_1111 == 127).
LOW_MASK = 0x7F
# Continuation bit: set on every varint byte except the last (0b1000_0000 == 128).
HIGH_MASK = 0x80

# Smallest multiple of 7 that is >= 64, i.e. the shift width needed to cover a
# 64 bit integer (== 70). We shouldn't have to decode integers larger than this.
SHIFT_64_BIT_MAX = math.ceil(64 / 7) * 7
|
|
|
|
|
|
def encode_uvarint(value: int) -> bytes:
    """
    Encode a non-negative integer as an unsigned LEB128 varint.

    Each output byte carries 7 bits of ``value``, least-significant group
    first; every byte except the last has its high (continuation) bit set.

    Raises:
        ValueError: If ``value`` is negative.

    """
    if value < 0:
        raise ValueError("Cannot encode negative value as uvarint")

    encoded = bytearray()
    remaining = value
    while True:
        low_bits = remaining & 0x7F
        remaining >>= 7
        if not remaining:
            # Final group: emitted without the continuation bit.
            encoded.append(low_bits)
            return bytes(encoded)
        encoded.append(low_bits | 0x80)
|
|
|
|
|
|
def decode_uvarint(data: bytes) -> int:
    """
    Decode an unsigned LEB128 varint from the start of ``data``.

    Decoding stops at the first byte whose high (continuation) bit is clear;
    any bytes after it are ignored.

    Raises:
        ParseError: If ``data`` is empty, or ends before a terminating byte
            (every byte has its continuation bit set).
        ValueError: If the varint continues past its 10th byte, i.e. it
            cannot fit in 64 bits.

    """
    if not data:
        raise ParseError("Unexpected end of data")

    result = 0
    shift = 0
    for byte in data:
        result |= (byte & 0x7F) << shift
        if (byte & 0x80) == 0:
            return result
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")

    # The loop only falls through when every byte had its continuation bit
    # set: the varint is truncated and ``result`` is just a fragment of the
    # real value, so returning it would silently yield a wrong number.
    raise ParseError("Unexpected end of data")
|
|
|
|
|
|
def decode_varint_from_bytes(data: bytes) -> int:
    """
    Decode an unsigned varint from ``data``.

    Backward-compatible alias that simply delegates to ``decode_uvarint``;
    it accepts the same input and raises the same errors.
    """
    decoded = decode_uvarint(data)
    return decoded
|
|
|
|
|
|
async def decode_uvarint_from_stream(reader: Reader) -> int:
    """
    Read and decode an unsigned LEB128 varint from ``reader``, byte by byte.

    See https://en.wikipedia.org/wiki/LEB128.

    Raises:
        ParseError: If the varint does not terminate within 10 bytes, i.e. the
            encoded integer would not fit in 64 bits.
        Exceptions raised by ``read_exactly`` (e.g. when the reader runs out of
        data) propagate unchanged.

    """
    res = 0
    for shift in itertools.count(0, 7):
        # Shifts 0, 7, ..., 63 cover the 10 bytes a 64-bit varint can occupy.
        # The next shift (70 == SHIFT_64_BIT_MAX) would mean an 11th byte, so
        # reject it here. The previous `>` check still read an 11th byte and
        # accepted integers of up to 77 bits.
        if shift >= SHIFT_64_BIT_MAX:
            raise ParseError(
                "Varint decoding error: integer exceeds maximum size of 64 bits."
            )

        byte = await read_exactly(reader, 1)
        value = byte[0]

        res += (value & LOW_MASK) << shift

        # A clear continuation bit marks the final byte of the varint.
        if not value & HIGH_MASK:
            break
    return res
|
|
|
|
|
|
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
|
|
"""
|
|
Decode a varint from bytes and return both the value and the number of bytes
|
|
consumed.
|
|
|
|
Returns:
|
|
Tuple[int, int]: (value, bytes_consumed)
|
|
|
|
"""
|
|
result = 0
|
|
shift = 0
|
|
bytes_consumed = 0
|
|
|
|
for byte in data:
|
|
result |= (byte & 0x7F) << shift
|
|
bytes_consumed += 1
|
|
if (byte & 0x80) == 0:
|
|
break
|
|
shift += 7
|
|
if shift >= 64:
|
|
raise ValueError("Varint too long")
|
|
|
|
return result, bytes_consumed
|
|
|
|
|
|
def encode_varint_prefixed(data: bytes) -> bytes:
    """Return ``data`` preceded by its byte length encoded as an unsigned varint."""
    return encode_uvarint(len(data)) + data
|
|
|
|
|
|
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
    """
    Read one varint-length-prefixed payload from ``reader``.

    The length is decoded first with ``decode_uvarint_from_stream``, then
    exactly that many bytes are read with ``read_exactly`` and returned.
    """
    # NOTE(review): the decoded length is not bounded here, so a peer can
    # request a read of up to 2**64 - 1 bytes; consider a size limit.
    return await read_exactly(reader, await decode_uvarint_from_stream(reader))
|
|
|
|
|
|
# Delimited read/write, used by multistream-select.
|
|
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
|
|
|
|
|
|
def encode_delim(msg: bytes) -> bytes:
    """
    Frame ``msg`` for delimited transport.

    A trailing newline is appended, then the result is prefixed with its
    length as an unsigned varint.
    """
    return encode_varint_prefixed(msg + b"\n")
|
|
|
|
|
|
async def read_delim(reader: Reader) -> bytes:
    """
    Read one delimited message framed by ``encode_delim``.

    Reads a varint-length-prefixed payload and returns it with the trailing
    newline stripped.

    Raises:
        ParseError: If the payload is empty or does not end with a newline.
    """
    payload = await read_varint_prefixed_bytes(reader)
    if not payload:
        raise ParseError("`len(msg_bytes)` should not be 0")

    body, terminator = payload[:-1], payload[-1:]
    if terminator != b"\n":
        raise ParseError(
            f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={payload!r}'
        )
    return body
|
|
|
|
|
|
def read_varint_prefixed_bytes_sync(
    stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
    """
    Read varint-prefixed bytes from a stream.

    Args:
        stream: A stream-like object with a read() method
        max_length: Maximum allowed data length to prevent memory exhaustion

    Returns:
        bytes: The data without the length prefix

    Raises:
        ValueError: If the length prefix runs past 10 bytes (does not fit in
            64 bits) or the decoded length exceeds ``max_length``
        EOFError: If the stream ends unexpectedly

    """
    # Decode the unsigned varint length prefix one byte at a time. Stop after
    # the 10th byte (the longest 64-bit varint) so a stream of endless
    # continuation bytes cannot keep us reading forever.
    length = 0
    shift = 0
    while True:
        byte_data = stream.read(1)
        if not byte_data:
            raise EOFError("Stream ended while reading varint length prefix")

        byte = byte_data[0]
        length |= (byte & 0x7F) << shift
        if byte & 0x80 == 0:
            break
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")

    if length > max_length:
        raise ValueError(
            f"Data length {length} exceeds maximum allowed {max_length}"
        )

    # read(n) on raw/unbuffered streams may return fewer than n bytes, so keep
    # reading until the payload is complete or the stream reports EOF.
    chunks: list[bytes] = []
    remaining = length
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)

    data = b"".join(chunks)
    if len(data) != length:
        raise EOFError(f"Expected {length} bytes, got {len(data)}")

    return data
|
|
|
|
|
|
async def read_length_prefixed_protobuf(
    stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
    """
    Read a protobuf message from a stream, handling both formats.

    Args:
        stream: The stream to read from.
        use_varint_format: If True, the message is expected to be prefixed with
            its length as an unsigned varint. If False, whatever a single
            ``stream.read()`` call returns is treated as the raw message.
        max_length: Maximum accepted message size in bytes.

    Returns:
        bytes: The message payload, without any length prefix.

    Raises:
        ValueError: If the varint length prefix runs past 10 bytes.
        Exception: If no prefix/data is received, the message exceeds
            ``max_length``, or the stream ends before the full message arrives.
        Exceptions raised by ``stream.read`` propagate unchanged.

    """
    # NOTE(review): the bare ``Exception`` types are kept for backward
    # compatibility with existing callers; a dedicated error type would be
    # easier to catch precisely.
    if not use_varint_format:
        # Raw format: read all available data in one go.
        data = await stream.read()
        if not data:
            raise Exception("No data received in raw format")
        if len(data) > max_length:
            raise Exception(
                f"Message length {len(data)} exceeds maximum allowed {max_length}"
            )
        return data

    # Varint format: read the length prefix one byte at a time. Stop after the
    # 10th byte (the longest 64-bit varint) so a peer cannot keep sending
    # continuation bytes forever.
    length_bytes = b""
    while True:
        b = await stream.read(1)
        if not b:
            raise Exception("No length prefix received")

        length_bytes += b
        if b[0] & 0x80 == 0:
            break
        if len(length_bytes) >= 10:
            raise ValueError("Varint too long")

    msg_length = decode_varint_from_bytes(length_bytes)

    if msg_length > max_length:
        raise Exception(
            f"Message length {msg_length} exceeds maximum allowed {max_length}"
        )

    # A single stream.read(n) may return fewer than n bytes (for example when
    # the message arrives in several chunks), so keep reading until the payload
    # is complete or the stream returns no more data.
    chunks: list[bytes] = []
    received = 0
    while received < msg_length:
        chunk = await stream.read(msg_length - received)
        if not chunk:
            break
        chunks.append(chunk)
        received += len(chunk)

    data = b"".join(chunks)
    if len(data) != msg_length:
        raise Exception(
            f"Incomplete message: expected {msg_length}, got {len(data)}"
        )

    return data
|