breaking: identify protocol use now prefix-length messages by default. use use_varint_format param for old raw messages

This commit is contained in:
acul71
2025-07-13 17:24:56 +02:00
parent 4bbb08ce2d
commit 1c59653946
6 changed files with 209 additions and 57 deletions

View File

@ -27,7 +27,7 @@ if TYPE_CHECKING:
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
return OrderedDict(
(
(IdentifyID, identify_handler_for(host, use_varint_format=False)),
(IdentifyID, identify_handler_for(host, use_varint_format=True)),
(PingID, handle_ping),
)
)

View File

@ -47,40 +47,57 @@ AGENT_VERSION = get_agent_version()
CONCURRENCY_LIMIT = 10
def identify_push_handler_for(host: IHost) -> StreamHandlerFn:
def identify_push_handler_for(
host: IHost, use_varint_format: bool = True
) -> StreamHandlerFn:
"""
Create a handler for the identify/push protocol.
This handler receives pushed identify messages from remote peers and updates
the local peerstore with the new information.
Args:
host: The libp2p host.
use_varint_format: If True, expect length-prefixed format; if False,
expect raw protobuf.
"""
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
try:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if use_varint_format:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
data = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
data += chunk
identify_msg = Identify()
identify_msg.ParseFromString(data)
@ -162,6 +179,7 @@ async def push_identify_to_peer(
peer_id: ID,
observed_multiaddr: Multiaddr | None = None,
limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT),
use_varint_format: bool = True,
) -> bool:
"""
Push an identify message to a specific peer.
@ -169,10 +187,16 @@ async def push_identify_to_peer(
This function opens a stream to the peer using the identify/push protocol,
sends the identify message, and closes the stream.
Returns
-------
bool
True if the push was successful, False otherwise.
Args:
host: The libp2p host.
peer_id: The peer ID to push to.
observed_multiaddr: The observed multiaddress (optional).
limit: Semaphore for concurrency control.
use_varint_format: If True, send length-prefixed format; if False,
send raw protobuf.
Returns:
bool: True if the push was successful, False otherwise.
"""
async with limit:
@ -184,9 +208,13 @@ async def push_identify_to_peer(
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
response = identify_msg.SerializeToString()
# Send length-prefixed identify message
await stream.write(varint.encode_uvarint(len(response)))
await stream.write(response)
if use_varint_format:
# Send length-prefixed identify message
await stream.write(varint.encode_uvarint(len(response)))
await stream.write(response)
else:
# Send raw protobuf message
await stream.write(response)
# Close the stream
await stream.close()
@ -202,18 +230,37 @@ async def push_identify_to_peers(
host: IHost,
peer_ids: set[ID] | None = None,
observed_multiaddr: Multiaddr | None = None,
use_varint_format: bool = True,
) -> None:
"""
Push an identify message to multiple peers in parallel.
If peer_ids is None, push to all connected peers.
Args:
host: The libp2p host.
peer_ids: Set of peer IDs to push to (if None, push to all connected peers).
observed_multiaddr: The observed multiaddress (optional).
use_varint_format: If True, send length-prefixed format; if False,
send raw protobuf.
"""
if peer_ids is None:
# Get all connected peers
peer_ids = set(host.get_connected_peers())
# Create a single shared semaphore for concurrency control
limit = trio.Semaphore(CONCURRENCY_LIMIT)
# Push to each peer in parallel using a trio.Nursery
# limiting concurrent connections to 10
# limiting concurrent connections to CONCURRENCY_LIMIT
async with trio.open_nursery() as nursery:
for peer_id in peer_ids:
nursery.start_soon(push_identify_to_peer, host, peer_id, observed_multiaddr)
nursery.start_soon(
push_identify_to_peer,
host,
peer_id,
observed_multiaddr,
limit,
use_varint_format,
)