Merge branch 'main' into feature/bootstrap

This commit is contained in:
Manu Sheel Gupta
2025-07-16 15:00:35 -07:00
committed by GitHub
24 changed files with 2294 additions and 63 deletions

View File

@ -26,5 +26,8 @@ if TYPE_CHECKING:
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
    """
    Build the protocol handlers that a host registers by default.

    Args:
        host: The host handed to ``identify_handler_for`` when the identify
            handler is created.

    Returns:
        An ordered mapping holding the identify handler (created with
        ``use_varint_format=True``, i.e. length-prefixed responses) followed
        by the ping handler.
    """
    # Only the merged entry list is kept; leaving the superseded pair tuple in
    # front of it would turn the expression into a call on a tuple.
    return OrderedDict(
        (
            (IdentifyID, identify_handler_for(host, use_varint_format=True)),
            (PingID, handle_ping),
        )
    )

View File

@ -16,7 +16,9 @@ from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
varint,
)
from .pb.identify_pb2 import (
@ -72,7 +74,47 @@ def _mk_identify_protobuf(
)
def identify_handler_for(host: IHost) -> StreamHandlerFn:
def parse_identify_response(response: bytes) -> Identify:
    """
    Decode an identify response that arrives in either wire format.

    The length-prefixed (new) format is attempted first: a leading varint is
    decoded and, when it consumed at least one byte, declares a positive
    length and that length fits inside ``response``, the framed slice is
    parsed. The framed result is only accepted when it carries a non-empty
    ``agent_version``; any exception raised during this attempt is ignored.

    Otherwise the whole buffer is parsed as a raw protobuf (old format).

    Args:
        response: The bytes received from the remote peer.

    Returns:
        The parsed ``Identify`` message.

    Raises:
        Exception: Whatever the raw-format parse raised, re-raised after the
            failure, the response length and its hex dump have been logged.
    """
    total = len(response)

    # New format: varint length prefix followed by the protobuf payload.
    if total >= 1:
        length, varint_size = decode_varint_with_size(response)
        payload_end = varint_size + length
        if varint_size > 0 and length > 0 and payload_end <= total:
            try:
                framed = Identify()
                framed.ParseFromString(response[varint_size:payload_end])
                # Sanity check: must have agent_version
                # (protocol_version is optional)
                if framed.agent_version:
                    logger.debug(
                        "Parsed length-prefixed identify response (new format)"
                    )
                    return framed
            except Exception:
                # Deliberate best effort: fall through to the old format.
                pass

    # Old format: the buffer is the protobuf message itself.
    try:
        unframed = Identify()
        unframed.ParseFromString(response)
        logger.debug("Parsed raw protobuf identify response (old format)")
        return unframed
    except Exception as e:
        logger.error(f"Failed to parse identify response: {e}")
        logger.error(f"Response length: {len(response)}")
        logger.error(f"Response hex: {response.hex()}")
        raise
def identify_handler_for(
host: IHost, use_varint_format: bool = False
) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
# get observed address from ``stream``
peer_id = (
@ -100,7 +142,21 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
response = protobuf.SerializeToString()
try:
await stream.write(response)
if use_varint_format:
# Send length-prefixed protobuf message (new format)
await stream.write(varint.encode_uvarint(len(response)))
await stream.write(response)
logger.debug(
"Sent new format (length-prefixed) identify response to %s",
peer_id,
)
else:
# Send raw protobuf message (old format for backward compatibility)
await stream.write(response)
logger.debug(
"Sent old format (raw protobuf) identify response to %s",
peer_id,
)
except StreamClosed:
logger.debug("Fail to respond to %s request: stream closed", ID)
else:

View File

@ -25,6 +25,10 @@ from libp2p.peer.id import (
)
from libp2p.utils import (
get_agent_version,
varint,
)
from libp2p.utils.varint import (
decode_varint_from_bytes,
)
from ..identify.identify import (
@ -43,20 +47,69 @@ AGENT_VERSION = get_agent_version()
CONCURRENCY_LIMIT = 10
def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
def identify_push_handler_for(
host: IHost, use_varint_format: bool = True
) -> StreamHandlerFn:
"""
Create a handler for the identify/push protocol.
This handler receives pushed identify messages from remote peers and updates
the local peerstore with the new information.
Args:
host: The libp2p host.
use_varint_format: True=length-prefixed, False=raw protobuf.
"""
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
try:
# Read the identify message from the stream
data = await stream.read()
if use_varint_format:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
# For raw format, we need to read all data before the stream is closed
data = b""
try:
# Read all available data in a single operation
data = await stream.read()
except StreamClosed:
# Try to read any remaining data
try:
data = await stream.read()
except Exception:
pass
# If we got no data, log a warning and return
if not data:
logger.warning(
"No data received in raw format from peer %s", peer_id
)
return
identify_msg = Identify()
identify_msg.ParseFromString(data)
@ -137,6 +190,7 @@ async def push_identify_to_peer(
peer_id: ID,
observed_multiaddr: Multiaddr | None = None,
limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT),
use_varint_format: bool = True,
) -> bool:
"""
Push an identify message to a specific peer.
@ -144,10 +198,15 @@ async def push_identify_to_peer(
This function opens a stream to the peer using the identify/push protocol,
sends the identify message, and closes the stream.
Returns
-------
bool
True if the push was successful, False otherwise.
Args:
host: The libp2p host.
peer_id: The peer ID to push to.
observed_multiaddr: The observed multiaddress (optional).
limit: Semaphore for concurrency control.
use_varint_format: True=length-prefixed, False=raw protobuf.
Returns:
bool: True if the push was successful, False otherwise.
"""
async with limit:
@ -159,8 +218,13 @@ async def push_identify_to_peer(
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
response = identify_msg.SerializeToString()
# Send the identify message
await stream.write(response)
if use_varint_format:
# Send length-prefixed identify message
await stream.write(varint.encode_uvarint(len(response)))
await stream.write(response)
else:
# Send raw protobuf message
await stream.write(response)
# Close the stream
await stream.close()
@ -176,18 +240,36 @@ async def push_identify_to_peers(
host: IHost,
peer_ids: set[ID] | None = None,
observed_multiaddr: Multiaddr | None = None,
use_varint_format: bool = True,
) -> None:
"""
Push an identify message to multiple peers in parallel.
If peer_ids is None, push to all connected peers.
Args:
host: The libp2p host.
peer_ids: Set of peer IDs to push to (if None, push to all connected peers).
observed_multiaddr: The observed multiaddress (optional).
use_varint_format: True=length-prefixed, False=raw protobuf.
"""
if peer_ids is None:
# Get all connected peers
peer_ids = set(host.get_connected_peers())
# Create a single shared semaphore for concurrency control
limit = trio.Semaphore(CONCURRENCY_LIMIT)
# Push to each peer in parallel using a trio.Nursery
# limiting concurrent connections to 10
# limiting concurrent connections to CONCURRENCY_LIMIT
async with trio.open_nursery() as nursery:
for peer_id in peer_ids:
nursery.start_soon(push_identify_to_peer, host, peer_id, observed_multiaddr)
nursery.start_soon(
push_identify_to_peer,
host,
peer_id,
observed_multiaddr,
limit,
use_varint_format,
)

View File

@ -7,6 +7,8 @@ from libp2p.utils.varint import (
encode_varint_prefixed,
read_delim,
read_varint_prefixed_bytes,
decode_varint_from_bytes,
decode_varint_with_size,
)
from libp2p.utils.version import (
get_agent_version,
@ -20,4 +22,6 @@ __all__ = [
"get_agent_version",
"read_delim",
"read_varint_prefixed_bytes",
"decode_varint_from_bytes",
"decode_varint_with_size",
]

View File

@ -39,12 +39,38 @@ def encode_uvarint(number: int) -> bytes:
return buf
def decode_varint_from_bytes(data: bytes) -> int:
    """
    Decode a varint from bytes that have already been read.

    Synchronous counterpart of ``decode_uvarint_from_stream``. Decoding stops
    at the first byte whose ``HIGH_MASK`` bit is clear; any bytes after it are
    ignored.

    Args:
        data: Bytes beginning with the encoded varint.

    Returns:
        The decoded integer value.

    Raises:
        ParseError: If the shift would exceed ``SHIFT_64_BIT_MAX`` or the data
            ends before a terminating byte is seen.
    """
    octets = iter(data)
    total = 0
    for shift in itertools.count(0, 7):
        # The size check comes before the exhaustion check on purpose, so an
        # over-long input reports "too large" rather than "end of data".
        if shift > SHIFT_64_BIT_MAX:
            raise ParseError("Integer is too large...")
        octet = next(octets, None)
        if octet is None:
            raise ParseError("Unexpected end of data")
        total += (octet & LOW_MASK) << shift
        if not octet & HIGH_MASK:
            break
    return total
async def decode_uvarint_from_stream(reader: Reader) -> int:
"""https://en.wikipedia.org/wiki/LEB128."""
res = 0
for shift in itertools.count(0, 7):
if shift > SHIFT_64_BIT_MAX:
raise ParseError("TODO: better exception msg: Integer is too large...")
raise ParseError(
"Varint decoding error: integer exceeds maximum size of 64 bits."
)
byte = await read_exactly(reader, 1)
value = byte[0]
@ -56,6 +82,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
return res
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode the varint at the start of ``data`` and report its encoded width.

    Args:
        data: Bytes expected to begin with a varint.

    Returns:
        ``(value, bytes_consumed)``. ``(0, 0)`` is returned when ``data`` is
        empty or when anything fails while decoding.
    """
    try:
        # 1-based position of the first byte with the continuation bit clear.
        # When no byte terminates the varint, every byte is handed to the
        # decoder, which then rejects the truncated input.
        consumed = next(
            (pos for pos, octet in enumerate(data, start=1) if (octet & 0x80) == 0),
            len(data),
        )
        if consumed == 0:
            return 0, 0
        return decode_varint_from_bytes(data[:consumed]), consumed
    except Exception:
        return 0, 0
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
    """
    Frame a message with its length.

    Args:
        msg_bytes: The payload to frame.

    Returns:
        ``encode_uvarint(len(msg_bytes))`` followed by ``msg_bytes``.
    """
    return encode_uvarint(len(msg_bytes)) + msg_bytes