mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
feat: Add identify-push raw format support and yamux logging improvements
- Add comprehensive integration tests for identify-push protocol
- Support both raw protobuf and varint message formats
- Improve yamux logging integration with LIBP2P_DEBUG
- Fix RawConnError handling to reduce log noise
- Add Ctrl+C handling to identify examples
- Enhance identify-push listener/dialer demo

Fixes: #784
This commit is contained in:
@ -9,6 +9,7 @@ from libp2p.utils.varint import (
|
||||
read_varint_prefixed_bytes,
|
||||
decode_varint_from_bytes,
|
||||
decode_varint_with_size,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
@ -24,4 +25,5 @@ __all__ = [
|
||||
"read_varint_prefixed_bytes",
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import BinaryIO
|
||||
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.exceptions import (
|
||||
ParseError,
|
||||
)
|
||||
@ -25,42 +27,41 @@ HIGH_MASK = 2**7
|
||||
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
|
||||
|
||||
|
||||
def encode_uvarint(value: int) -> bytes:
    """
    Encode a non-negative integer as an unsigned varint.

    The integer is written least-significant group first, seven bits per
    byte. Every byte except the last has its high bit (0x80) set to signal
    that more bytes follow.

    Args:
        value: The non-negative integer to encode.

    Returns:
        bytes: The varint encoding of ``value``; always at least one byte.

    Raises:
        ValueError: If ``value`` is negative.

    """
    if value < 0:
        raise ValueError("Cannot encode negative value as uvarint")

    # A bytearray keeps appends amortised O(1); repeated ``bytes +=`` would
    # copy the whole buffer on every iteration.
    result = bytearray()
    while value >= 0x80:
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    # What is left fits in seven bits, so the continuation bit stays clear.
    result.append(value & 0x7F)
    return bytes(result)
|
||||
|
||||
|
||||
def decode_uvarint(data: bytes) -> int:
    """
    Decode the unsigned varint found at the start of ``data``.

    Bytes are consumed until one with a clear high bit (0x80) is seen; any
    bytes after that terminating byte are ignored.

    Args:
        data: Buffer that starts with a varint.

    Returns:
        int: The decoded value.

    Raises:
        ParseError: If ``data`` is empty or ends before the varint's
            terminating byte.
        ValueError: If the varint keeps going once the shift reaches 64 bits.

    """
    result = 0
    shift = 0

    for byte in data:
        result |= (byte & 0x7F) << shift
        if (byte & 0x80) == 0:
            return result
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")

    # Reached only when the buffer is empty or every byte carried the
    # continuation bit: the value is incomplete, so report it rather than
    # returning a partial result.
    raise ParseError("Unexpected end of data")
|
||||
|
||||
|
||||
def decode_varint_from_bytes(data: bytes) -> int:
    """
    Decode a varint from already-read bytes and return its value.

    Kept as an alias of ``decode_uvarint`` for backward compatibility; all
    decoding, including error reporting, is delegated to that function.

    Args:
        data: Buffer that starts with a varint.

    Returns:
        int: The decoded value.

    """
    return decode_uvarint(data)
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
|
||||
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode the varint at the start of ``data`` and report its encoded size.

    Args:
        data: Buffer that may start with a varint.

    Returns:
        tuple[int, int]: ``(value, bytes_consumed)``. ``(0, 0)`` is returned
        when ``data`` does not start with a complete varint, i.e. it is
        empty or ends before a byte with a clear high bit is seen.

    Raises:
        ValueError: If the varint keeps going once the shift reaches 64 bits.

    """
    result = 0
    shift = 0

    for bytes_consumed, byte in enumerate(data, start=1):
        result |= (byte & 0x7F) << shift
        if (byte & 0x80) == 0:
            return result, bytes_consumed
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")

    # No terminating byte was found: a size of 0 tells the caller that
    # nothing was decoded, instead of handing back a partial value.
    return 0, 0
|
||||
|
||||
|
||||
def encode_varint_prefixed(data: bytes) -> bytes:
    """
    Return ``data`` preceded by its length encoded as an unsigned varint.

    Args:
        data: The payload to frame.

    Returns:
        bytes: ``encode_uvarint(len(data))`` followed by ``data``.

    """
    return encode_uvarint(len(data)) + data
|
||||
|
||||
|
||||
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
|
||||
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
|
||||
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
|
||||
)
|
||||
return msg_bytes[:-1]
|
||||
|
||||
|
||||
def read_varint_prefixed_bytes_sync(
    stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
    """
    Read one varint-length-prefixed payload from a blocking stream.

    Args:
        stream: A stream-like object with a read() method
        max_length: Maximum allowed data length to prevent memory exhaustion

    Returns:
        bytes: The data without the length prefix

    Raises:
        ValueError: If the length prefix is invalid or too large
        EOFError: If the stream ends unexpectedly

    """
    # Decode the prefix as it arrives so that an endless run of continuation
    # bytes is rejected after a bounded number of reads instead of being
    # buffered until the stream ends.
    length = 0
    shift = 0
    while True:
        byte_data = stream.read(1)
        if not byte_data:
            raise EOFError("Stream ended while reading varint length prefix")

        byte = byte_data[0]
        length |= (byte & 0x7F) << shift
        if byte & 0x80 == 0:
            break
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")

    if length > max_length:
        raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")

    # A single read() may legally return fewer bytes than requested, so keep
    # reading until the payload is complete or the stream really ends.
    chunks = []
    remaining = length
    while remaining:
        chunk = stream.read(remaining)
        if not chunk:
            raise EOFError(f"Expected {length} bytes, got {length - remaining}")
        chunks.append(chunk)
        remaining -= len(chunk)

    return b"".join(chunks)
|
||||
|
||||
|
||||
async def read_length_prefixed_protobuf(
    stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
    """
    Read a protobuf message from a stream, handling both formats.

    Args:
        stream: The stream to read from.
        use_varint_format: When true, expect a varint length prefix followed
            by exactly that many message bytes. When false, treat whatever a
            single unsized ``stream.read()`` returns as the whole message.
        max_length: Maximum accepted message size in bytes.

    Returns:
        bytes: The serialized message, without any length prefix.

    Raises:
        Exception: If no data or no length prefix is received, the length
            prefix is too long, the message exceeds ``max_length``, or the
            stream ends before the announced number of bytes has arrived.

    """
    if use_varint_format:
        # A 64-bit length needs at most ceil(64 / 7) == 10 varint bytes;
        # capping the prefix stops a peer from feeding continuation bytes
        # forever.
        max_prefix_bytes = 10

        # First read the varint length prefix
        length_bytes = bytearray()
        while True:
            b = await stream.read(1)
            if not b:
                raise Exception("No length prefix received")

            length_bytes += b
            if b[0] & 0x80 == 0:
                break
            if len(length_bytes) >= max_prefix_bytes:
                raise Exception("Length prefix too long")

        msg_length = decode_varint_from_bytes(bytes(length_bytes))

        if msg_length > max_length:
            raise Exception(
                f"Message length {msg_length} exceeds maximum allowed {max_length}"
            )

        # Read the protobuf message. One read() may hand back fewer bytes
        # than requested, so accumulate until the message is complete or the
        # stream stops producing data.
        chunks = []
        remaining = msg_length
        while remaining > 0:
            chunk = await stream.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)

        data = b"".join(chunks)
        if len(data) != msg_length:
            raise Exception(
                f"Incomplete message: expected {msg_length}, got {len(data)}"
            )

        return data
    else:
        # Read raw protobuf message from the stream
        # For raw format, read all available data in one go
        data = await stream.read()

        # If we got no data, raise an exception
        if not data:
            raise Exception("No data received in raw format")

        if len(data) > max_length:
            raise Exception(
                f"Message length {len(data)} exceeds maximum allowed {max_length}"
            )

        return data
|
||||
|
||||
Reference in New Issue
Block a user