Merge branch 'main' into py-multiaddr

This commit is contained in:
Manu Sheel Gupta
2025-07-21 06:27:11 -07:00
committed by GitHub
11 changed files with 1181 additions and 314 deletions

View File

@@ -1,6 +1,7 @@
 import argparse
 import base64
 import logging
+import sys

 import multiaddr
 import trio
@@ -112,7 +113,12 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
         # Replace the handler with our custom one
         host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)

-        await trio.sleep_forever()
+        try:
+            await trio.sleep_forever()
+        except KeyboardInterrupt:
+            print("\n🛑 Shutting down listener...")
+            logger.info("Listener interrupted by user")
+            return
     else:
         # Create second host (dialer)
@@ -147,38 +153,13 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
         try:
             print("Starting identify protocol...")

-            # Read the response properly based on the format
-            if use_varint_format:
-                # For length-prefixed format, read varint length first
-                from libp2p.utils.varint import decode_varint_from_bytes
-
-                # Read varint length prefix
-                length_bytes = b""
-                while True:
-                    b = await stream.read(1)
-                    if not b:
-                        raise Exception("Stream closed while reading varint length")
-                    length_bytes += b
-                    if b[0] & 0x80 == 0:
-                        break
-
-                msg_length = decode_varint_from_bytes(length_bytes)
-                print(f"Expected message length: {msg_length} bytes")
-
-                # Read the protobuf message
-                response = await stream.read(msg_length)
-                if len(response) != msg_length:
-                    raise Exception(
-                        f"Incomplete message: expected {msg_length} bytes, "
-                        f"got {len(response)}"
-                    )
-
-                # Combine length prefix and message
-                full_response = length_bytes + response
-            else:
-                # For raw format, read all available data
-                response = await stream.read(8192)
-                full_response = response
+            # Read the response using the utility function
+            from libp2p.utils.varint import read_length_prefixed_protobuf
+
+            response = await read_length_prefixed_protobuf(
+                stream, use_varint_format
+            )
+            full_response = response

             await stream.close()
@@ -254,6 +235,7 @@ def main() -> None:
             "length-prefixed (new format)"
         ),
     )
+
     args = parser.parse_args()

     # Determine format: raw format if --raw-format is specified, otherwise
@@ -261,9 +243,19 @@ def main() -> None:
     use_varint_format = not args.raw_format

     try:
-        trio.run(run, *(args.port, args.destination, use_varint_format))
+        if args.destination:
+            # Run in dialer mode
+            trio.run(run, *(args.port, args.destination, use_varint_format))
+        else:
+            # Run in listener mode
+            trio.run(run, *(args.port, args.destination, use_varint_format))
     except KeyboardInterrupt:
-        pass
+        print("\n👋 Goodbye!")
+        logger.info("Application interrupted by user")
+    except Exception as e:
+        print(f"\n❌ Error: {str(e)}")
+        logger.error("Error: %s", str(e))
+        sys.exit(1)


 if __name__ == "__main__":

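The refactor above collapses roughly thirty lines of hand-rolled varint framing into one call to the new shared utility. A minimal sketch of the resulting read path, assuming the read_length_prefixed_protobuf signature introduced in the varint changes later in this diff (the stream argument stands for any already-open INetStream):

    from libp2p.identity.identify.pb.identify_pb2 import Identify
    from libp2p.utils.varint import read_length_prefixed_protobuf


    async def read_identify_response(stream, use_varint_format: bool = True) -> Identify:
        # The helper hides both wire formats: a varint length prefix followed
        # by the protobuf body, or a raw protobuf read in one go.
        data = await read_length_prefixed_protobuf(stream, use_varint_format)
        msg = Identify()
        msg.ParseFromString(data)
        return msg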
View File

@@ -11,23 +11,26 @@ This example shows how to:
 import logging

+import multiaddr
 import trio

 from libp2p import (
     new_host,
 )
+from libp2p.abc import (
+    INetStream,
+)
 from libp2p.crypto.secp256k1 import (
     create_new_key_pair,
 )
 from libp2p.custom_types import (
     TProtocol,
 )
-from libp2p.identity.identify import (
-    identify_handler_for,
+from libp2p.identity.identify.pb.identify_pb2 import (
+    Identify,
 )
 from libp2p.identity.identify_push import (
     ID_PUSH,
-    identify_push_handler_for,
     push_identify_to_peer,
 )
 from libp2p.peer.peerinfo import (
@@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import (

 logger = logging.getLogger(__name__)


+def create_custom_identify_handler(host, host_name: str):
+    """Create a custom identify handler that displays received information."""
+
+    async def handle_identify(stream: INetStream) -> None:
+        peer_id = stream.muxed_conn.peer_id
+        print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
+
+        # Get the standard identify response using the existing function
+        from libp2p.identity.identify.identify import (
+            _mk_identify_protobuf,
+            _remote_address_to_multiaddr,
+        )
+
+        # Get observed address
+        observed_multiaddr = None
+        try:
+            remote_address = stream.get_remote_address()
+            if remote_address:
+                observed_multiaddr = _remote_address_to_multiaddr(remote_address)
+        except Exception:
+            pass
+
+        # Build the identify protobuf
+        identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
+        response_data = identify_msg.SerializeToString()
+
+        print(f"  📋 {host_name} identify information:")
+        if identify_msg.HasField("protocol_version"):
+            print(f"    Protocol Version: {identify_msg.protocol_version}")
+        if identify_msg.HasField("agent_version"):
+            print(f"    Agent Version: {identify_msg.agent_version}")
+        if identify_msg.HasField("public_key"):
+            print(f"    Public Key: {identify_msg.public_key.hex()[:16]}...")
+        if identify_msg.listen_addrs:
+            print("    Listen Addresses:")
+            for addr_bytes in identify_msg.listen_addrs:
+                addr = multiaddr.Multiaddr(addr_bytes)
+                print(f"      - {addr}")
+        if identify_msg.protocols:
+            print("    Supported Protocols:")
+            for protocol in identify_msg.protocols:
+                print(f"      - {protocol}")
+
+        # Send the response
+        await stream.write(response_data)
+        await stream.close()
+
+    return handle_identify
+
+
+def create_custom_identify_push_handler(host, host_name: str):
+    """Create a custom identify/push handler that displays received information."""
+
+    async def handle_identify_push(stream: INetStream) -> None:
+        peer_id = stream.muxed_conn.peer_id
+        print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")
+
+        try:
+            # Read the identify message using the utility function
+            from libp2p.utils.varint import read_length_prefixed_protobuf
+
+            data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
+
+            # Parse the identify message
+            identify_msg = Identify()
+            identify_msg.ParseFromString(data)
+
+            print("  📋 Received identify information:")
+            if identify_msg.HasField("protocol_version"):
+                print(f"    Protocol Version: {identify_msg.protocol_version}")
+            if identify_msg.HasField("agent_version"):
+                print(f"    Agent Version: {identify_msg.agent_version}")
+            if identify_msg.HasField("public_key"):
+                print(f"    Public Key: {identify_msg.public_key.hex()[:16]}...")
+            if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
+                observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
+                print(f"    Observed Address: {observed_addr}")
+            if identify_msg.listen_addrs:
+                print("    Listen Addresses:")
+                for addr_bytes in identify_msg.listen_addrs:
+                    addr = multiaddr.Multiaddr(addr_bytes)
+                    print(f"      - {addr}")
+            if identify_msg.protocols:
+                print("    Supported Protocols:")
+                for protocol in identify_msg.protocols:
+                    print(f"      - {protocol}")
+
+            # Update the peerstore with the new information
+            from libp2p.identity.identify_push.identify_push import (
+                _update_peerstore_from_identify,
+            )
+
+            await _update_peerstore_from_identify(
+                host.get_peerstore(), peer_id, identify_msg
+            )
+            print(f"{host_name} updated peerstore with new information")
+
+        except Exception as e:
+            print(f"  ❌ Error processing identify/push: {e}")
+        finally:
+            await stream.close()
+
+    return handle_identify_push
+
+
+async def display_peerstore_info(host, host_name: str, peer_id, description: str):
+    """Display peerstore information for a specific peer."""
+    peerstore = host.get_peerstore()
+
+    try:
+        addrs = peerstore.addrs(peer_id)
+    except Exception:
+        addrs = []
+
+    try:
+        protocols = peerstore.get_protocols(peer_id)
+    except Exception:
+        protocols = []
+
+    print(f"\n📚 {host_name} peerstore for {description}:")
+    print(f"  Peer ID: {peer_id}")
+    if addrs:
+        print("  Addresses:")
+        for addr in addrs:
+            print(f"    - {addr}")
+    else:
+        print("  Addresses: None")
+    if protocols:
+        print("  Protocols:")
+        for protocol in protocols:
+            print(f"    - {protocol}")
+    else:
+        print("  Protocols: None")
+
+
 async def main() -> None:
-    print("\n==== Starting Identify-Push Example ====\n")
+    print("\n==== Starting Enhanced Identify-Push Example ====\n")

     # Create key pairs for the two hosts
     key_pair_1 = create_new_key_pair()
@@ -48,45 +188,49 @@ async def main() -> None:
     # Create the first host
     host_1 = new_host(key_pair=key_pair_1)

-    # Set up the identify and identify/push handlers
-    host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
-    host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
+    # Set up custom identify and identify/push handlers
+    host_1.set_stream_handler(
+        TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
+    )
+    host_1.set_stream_handler(
+        ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
+    )

     # Create the second host
     host_2 = new_host(key_pair=key_pair_2)

-    # Set up the identify and identify/push handlers
-    host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
-    host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
+    # Set up custom identify and identify/push handlers
+    host_2.set_stream_handler(
+        TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
+    )
+    host_2.set_stream_handler(
+        ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
+    )

     # Start listening on random ports using the run context manager
-    import multiaddr
-
     listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
     listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")

     async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
         # Get the addresses of both hosts
         addr_1 = host_1.get_addrs()[0]
-        logger.info(f"Host 1 listening on {addr_1}")
-        print(f"Host 1 listening on {addr_1}")
-        print(f"Peer ID: {host_1.get_id().pretty()}")
-
         addr_2 = host_2.get_addrs()[0]
-        logger.info(f"Host 2 listening on {addr_2}")
-        print(f"Host 2 listening on {addr_2}")
-        print(f"Peer ID: {host_2.get_id().pretty()}")

-        print("\nConnecting Host 2 to Host 1...")
+        print("🏠 Host Configuration:")
+        print(f"  Host 1: {addr_1}")
+        print(f"  Host 1 Peer ID: {host_1.get_id().pretty()}")
+        print(f"  Host 2: {addr_2}")
+        print(f"  Host 2 Peer ID: {host_2.get_id().pretty()}")
+
+        print("\n🔗 Connecting Host 2 to Host 1...")

         # Connect host_2 to host_1
         peer_info = info_from_p2p_addr(addr_1)
         await host_2.connect(peer_info)
-        logger.info("Host 2 connected to Host 1")
-        print("Host 2 successfully connected to Host 1")
+        print("Host 2 successfully connected to Host 1")

         # Run the identify protocol from host_2 to host_1
-        # (so Host 1 learns Host 2's address)
+        print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
         from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID

         stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
@@ -94,64 +238,58 @@ async def main() -> None:
         await stream.close()

         # Run the identify protocol from host_1 to host_2
-        # (so Host 2 learns Host 1's address)
+        print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
         stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
         response = await stream.read()
         await stream.close()

-        # --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
-        from libp2p.identity.identify.pb.identify_pb2 import (
-            Identify,
-        )
-
+        # Update Host 1's peerstore with Host 2's addresses
         identify_msg = Identify()
         identify_msg.ParseFromString(response)
         peerstore_1 = host_1.get_peerstore()
         peer_id_2 = host_2.get_id()
         for addr_bytes in identify_msg.listen_addrs:
             maddr = multiaddr.Multiaddr(addr_bytes)
-            # TTL can be any positive int
-            peerstore_1.add_addr(
-                peer_id_2,
-                maddr,
-                ttl=3600,
-            )
-        # --- END NEW CODE ---
+            peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)

-        # Now Host 1's peerstore should have Host 2's address
-        peerstore_1 = host_1.get_peerstore()
-        peer_id_2 = host_2.get_id()
-        addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
-        logger.info(
-            f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
-            f"{addrs_1_for_2}"
-        )
-        print(
-            f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
-            f"{addrs_1_for_2}"
-        )
+        # Display peerstore information before push
+        await display_peerstore_info(
+            host_1, "Host 1", peer_id_2, "Host 2 (before push)"
+        )

         # Push identify information from host_1 to host_2
-        logger.info("Host 1 pushing identify information to Host 2")
-        print("\nHost 1 pushing identify information to Host 2...")
+        print("\n📤 Host 1 pushing identify information to Host 2...")
         try:
-            # Call push_identify_to_peer which now returns a boolean
             success = await push_identify_to_peer(host_1, host_2.get_id())
             if success:
-                logger.info("Identify push completed successfully")
-                print("Identify push completed successfully!")
+                print("Identify push completed successfully!")
             else:
-                logger.warning("Identify push didn't complete successfully")
-                print("\nWarning: Identify push didn't complete successfully")
+                print("⚠️ Identify push didn't complete successfully")
         except Exception as e:
-            logger.error(f"Error during identify push: {str(e)}")
-            print(f"\nError during identify push: {str(e)}")
+            print(f"Error during identify push: {str(e)}")

-        # Add this at the end of your async with block:
-        await trio.sleep(0.5)  # Give background tasks time to finish
+        # Give a moment for the identify/push processing to complete
+        await trio.sleep(0.5)
+
+        # Display peerstore information after push
+        await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
+        await display_peerstore_info(
+            host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
+        )
+
+        # Give more time for background tasks to finish and connections to stabilize
+        print("\n⏳ Waiting for background tasks to complete...")
+        await trio.sleep(1.0)
+
+        # Gracefully close connections to prevent connection errors
+        print("🔌 Closing connections...")
+        await host_2.disconnect(host_1.get_id())
+        await trio.sleep(0.2)
+
+        print("\n🎉 Example completed successfully!")


 if __name__ == "__main__":

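For reference, the peerstore update this example performs after the manual identify round-trip can be isolated into a small helper. This is a sketch under the assumption that peerstore.add_addr takes (peer_id, multiaddr, ttl) as shown in the diff, with the TTL in seconds (3600 mirrors the example and is otherwise arbitrary):

    import multiaddr
    from libp2p.identity.identify.pb.identify_pb2 import Identify


    def update_peerstore_from_response(peerstore, peer_id, response: bytes, ttl: int = 3600) -> None:
        # Parse the raw identify response and fold the peer's advertised
        # listen addresses into the local peerstore.
        msg = Identify()
        msg.ParseFromString(response)
        for addr_bytes in msg.listen_addrs:
            # listen_addrs entries are serialized multiaddrs
            peerstore.add_addr(peer_id, multiaddr.Multiaddr(addr_bytes), ttl=ttl)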
View File

@@ -41,6 +41,9 @@ from libp2p.identity.identify import (
     ID as ID_IDENTIFY,
     identify_handler_for,
 )
+from libp2p.identity.identify.identify import (
+    _remote_address_to_multiaddr,
+)
 from libp2p.identity.identify.pb.identify_pb2 import (
     Identify,
 )
@@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
     async def handle_identify_push(stream: INetStream) -> None:
         peer_id = stream.muxed_conn.peer_id

+        # Get remote address information
         try:
-            if use_varint_format:
-                # Read length-prefixed identify message from the stream
-                from libp2p.utils.varint import decode_varint_from_bytes
-
-                # First read the varint length prefix
-                length_bytes = b""
-                while True:
-                    b = await stream.read(1)
-                    if not b:
-                        break
-                    length_bytes += b
-                    if b[0] & 0x80 == 0:
-                        break
-
-                if not length_bytes:
-                    logger.warning("No length prefix received from peer %s", peer_id)
-                    return
-
-                msg_length = decode_varint_from_bytes(length_bytes)
-
-                # Read the protobuf message
-                data = await stream.read(msg_length)
-                if len(data) != msg_length:
-                    logger.warning("Incomplete message received from peer %s", peer_id)
-                    return
-            else:
-                # Read raw protobuf message from the stream
-                data = b""
-                while True:
-                    chunk = await stream.read(4096)
-                    if not chunk:
-                        break
-                    data += chunk
+            remote_address = stream.get_remote_address()
+            if remote_address:
+                observed_multiaddr = _remote_address_to_multiaddr(remote_address)
+                logger.info(
+                    "Connection from remote peer %s, address: %s, multiaddr: %s",
+                    peer_id,
+                    remote_address,
+                    observed_multiaddr,
+                )
+                print(f"\n🔗 Received identify/push request from peer: {peer_id}")
+                # Add the peer ID to create a complete multiaddr
+                complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
+                print(f"  Remote address: {complete_multiaddr}")
+        except Exception as e:
+            logger.error("Error getting remote address: %s", e)
+            print(f"\n🔗 Received identify/push request from peer: {peer_id}")
+
+        try:
+            # Use the utility function to read the protobuf message
+            from libp2p.utils.varint import read_length_prefixed_protobuf
+
+            data = await read_length_prefixed_protobuf(stream, use_varint_format)

             identify_msg = Identify()
             identify_msg.ParseFromString(data)
@@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
             await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)

             logger.info("Successfully processed identify/push from peer %s", peer_id)
-            print(f"\nSuccessfully processed identify/push from peer {peer_id}")
+            print(f"Successfully processed identify/push from peer {peer_id}")
         except Exception as e:
-            logger.error("Error processing identify/push from %s: %s", peer_id, e)
-            print(f"\nError processing identify/push from {peer_id}: {e}")
+            error_msg = str(e)
+            logger.error(
+                "Error processing identify/push from %s: %s", peer_id, error_msg
+            )
+            print(f"\nError processing identify/push from {peer_id}: {error_msg}")
+
+            # Check for specific format mismatch errors
+            if (
+                "Error parsing message" in error_msg
+                or "DecodeError" in error_msg
+                or "ParseFromString" in error_msg
+            ):
+                print("\n" + "=" * 60)
+                print("FORMAT MISMATCH DETECTED!")
+                print("=" * 60)
+                if use_varint_format:
+                    print(
+                        "You are using length-prefixed format (default) but the "
+                        "dialer is using raw protobuf format."
+                    )
+                    print("\nTo fix this, run the dialer with the --raw-format flag:")
+                    print(
+                        "identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
+                    )
+                else:
+                    print("You are using raw protobuf format but the dialer")
+                    print("is using length-prefixed format (default).")
+                    print(
+                        "\nTo fix this, run the dialer without the --raw-format flag:"
+                    )
+                    print("identify-push-listener-dialer-demo -d <ADDRESS>")
+                print("=" * 60)
         finally:
             # Close the stream after processing
             await stream.close()
@@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
     return handle_identify_push


-async def run_listener(port: int, use_varint_format: bool = True) -> None:
+async def run_listener(
+    port: int, use_varint_format: bool = True, raw_format_flag: bool = False
+) -> None:
     """Run a host in listener mode."""
     format_name = "length-prefixed" if use_varint_format else "raw protobuf"
     print(
@@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
     )
     host.set_stream_handler(
         ID_IDENTIFY_PUSH,
-        identify_push_handler_for(host, use_varint_format=use_varint_format),
+        custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
     )

     # Start listening
     listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")

-    async with host.run([listen_addr]):
-        addr = host.get_addrs()[0]
-        logger.info("Listener host ready!")
-        print("Listener host ready!")
-        logger.info(f"Listening on: {addr}")
-        print(f"Listening on: {addr}")
-        logger.info(f"Peer ID: {host.get_id().pretty()}")
-        print(f"Peer ID: {host.get_id().pretty()}")
-        print("\nRun dialer with command:")
-        print(f"identify-push-listener-dialer-demo -d {addr}")
-        print("\nWaiting for incoming connections... (Ctrl+C to exit)")
-
-        # Keep running until interrupted
-        await trio.sleep_forever()
+    try:
+        async with host.run([listen_addr]):
+            addr = host.get_addrs()[0]
+            logger.info("Listener host ready!")
+            print("Listener host ready!")
+            logger.info(f"Listening on: {addr}")
+            print(f"Listening on: {addr}")
+            logger.info(f"Peer ID: {host.get_id().pretty()}")
+            print(f"Peer ID: {host.get_id().pretty()}")
+            print("\nRun dialer with command:")
+            if raw_format_flag:
+                print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
+            else:
+                print(f"identify-push-listener-dialer-demo -d {addr}")
+            print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
+
+            # Keep running until interrupted
+            try:
+                await trio.sleep_forever()
+            except KeyboardInterrupt:
+                print("\n🛑 Shutting down listener...")
+                logger.info("Listener interrupted by user")
+                return
+    except Exception as e:
+        logger.error(f"Listener error: {e}")
+        raise


 async def run_dialer(
@@ -256,7 +293,9 @@ async def run_dialer(
     try:
         await host.connect(peer_info)
         logger.info("Successfully connected to listener!")
         print("Successfully connected to listener!")
+        print(f"  Connected to: {peer_info.peer_id}")
+        print(f"  Full address: {destination}")

         # Push identify information to the listener
         logger.info("Pushing identify information to listener...")
@@ -270,7 +309,7 @@ async def run_dialer(
         if success:
             logger.info("Identify push completed successfully!")
             print("Identify push completed successfully!")

             logger.info("Example completed successfully!")
             print("\nExample completed successfully!")
@@ -281,17 +320,57 @@ async def run_dialer(
             logger.warning("Example completed with warnings.")
             print("Example completed with warnings.")
     except Exception as e:
-        logger.error(f"Error during identify push: {str(e)}")
-        print(f"\nError during identify push: {str(e)}")
+        error_msg = str(e)
+        logger.error(f"Error during identify push: {error_msg}")
+        print(f"\nError during identify push: {error_msg}")
+
+        # Check for specific format mismatch errors
+        if (
+            "Error parsing message" in error_msg
+            or "DecodeError" in error_msg
+            or "ParseFromString" in error_msg
+        ):
+            print("\n" + "=" * 60)
+            print("FORMAT MISMATCH DETECTED!")
+            print("=" * 60)
+            if use_varint_format:
+                print(
+                    "You are using length-prefixed format (default) but the "
+                    "listener is using raw protobuf format."
+                )
+                print(
+                    "\nTo fix this, run the dialer with the --raw-format flag:"
+                )
+                print(
+                    f"identify-push-listener-dialer-demo --raw-format -d "
+                    f"{destination}"
+                )
+            else:
+                print("You are using raw protobuf format but the listener")
+                print("is using length-prefixed format (default).")
+                print(
+                    "\nTo fix this, run the dialer without the --raw-format "
+                    "flag:"
+                )
+                print(f"identify-push-listener-dialer-demo -d {destination}")
+            print("=" * 60)

         logger.error("Example completed with errors.")
         print("Example completed with errors.")
         # Continue execution despite the push error
     except Exception as e:
-        logger.error(f"Error during dialer operation: {str(e)}")
-        print(f"\nError during dialer operation: {str(e)}")
-        raise
+        error_msg = str(e)
+        if "unable to connect" in error_msg or "SwarmException" in error_msg:
+            print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
+            print(f"  Address: {destination}")
+            print(f"  Error: {error_msg}")
+            print("\n💡 Make sure the peer is running and the address is correct.")
+            return
+        else:
+            logger.error(f"Error during dialer operation: {error_msg}")
+            print(f"\nError during dialer operation: {error_msg}")
+            raise


 def main() -> None:
@@ -301,12 +380,21 @@ def main() -> None:
    Without arguments, it runs as a listener on random port.
    With -d parameter, it runs as a dialer on random port.

+    Port 0 (default) means the OS will automatically assign an available port.
+    This prevents port conflicts when running multiple instances.
+
    Use --raw-format to send raw protobuf messages (old format) instead of
    length-prefixed protobuf messages (new format, default).
    """
    parser = argparse.ArgumentParser(description=description)
-    parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
+    parser.add_argument(
+        "-p",
+        "--port",
+        default=0,
+        type=int,
+        help="source port number (0 = random available port)",
+    )
    parser.add_argument(
        "-d",
        "--destination",
@@ -321,6 +409,7 @@ def main() -> None:
             "length-prefixed (new format)"
         ),
     )
+
     args = parser.parse_args()

     # Determine format: raw format if --raw-format is specified, otherwise
@@ -333,12 +422,12 @@ def main() -> None:
             trio.run(run_dialer, args.port, args.destination, use_varint_format)
         else:
             # Run in listener mode with random available port if not specified
-            trio.run(run_listener, args.port, use_varint_format)
+            trio.run(run_listener, args.port, use_varint_format, args.raw_format)
     except KeyboardInterrupt:
-        print("\nInterrupted by user")
-        logger.info("Interrupted by user")
+        print("\n👋 Goodbye!")
+        logger.info("Application interrupted by user")
     except Exception as e:
         print(f"\nError: {str(e)}")
         logger.error("Error: %s", str(e))
         sys.exit(1)

View File

@@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:

 def identify_handler_for(
-    host: IHost, use_varint_format: bool = False
+    host: IHost, use_varint_format: bool = True
 ) -> StreamHandlerFn:
     async def handle_identify(stream: INetStream) -> None:
         # get observed address from ``stream``

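Note the behavioral consequence of this one-line change: identify handlers now default to the length-prefixed wire format. A hedged sketch of opting back into the old raw framing for a host that still needs it (the helper function and its name are illustrative, not part of the diff):

    from libp2p.abc import IHost
    from libp2p.custom_types import TProtocol
    from libp2p.identity.identify import identify_handler_for


    def use_legacy_raw_identify(host: IHost) -> None:
        # Explicitly request the pre-change raw protobuf framing now that
        # the default has flipped to length-prefixed.
        host.set_stream_handler(
            TProtocol("/ipfs/id/1.0.0"),
            identify_handler_for(host, use_varint_format=False),
        )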
View File

@@ -28,7 +28,7 @@ from libp2p.utils import (
     varint,
 )
 from libp2p.utils.varint import (
-    decode_varint_from_bytes,
+    read_length_prefixed_protobuf,
 )

 from ..identify.identify import (
@@ -66,49 +66,8 @@ def identify_push_handler_for(
         peer_id = stream.muxed_conn.peer_id
         try:
-            if use_varint_format:
-                # Read length-prefixed identify message from the stream
-                # First read the varint length prefix
-                length_bytes = b""
-                while True:
-                    b = await stream.read(1)
-                    if not b:
-                        break
-                    length_bytes += b
-                    if b[0] & 0x80 == 0:
-                        break
-
-                if not length_bytes:
-                    logger.warning("No length prefix received from peer %s", peer_id)
-                    return
-
-                msg_length = decode_varint_from_bytes(length_bytes)
-
-                # Read the protobuf message
-                data = await stream.read(msg_length)
-                if len(data) != msg_length:
-                    logger.warning("Incomplete message received from peer %s", peer_id)
-                    return
-            else:
-                # Read raw protobuf message from the stream
-                # For raw format, we need to read all data before the stream is closed
-                data = b""
-                try:
-                    # Read all available data in a single operation
-                    data = await stream.read()
-                except StreamClosed:
-                    # Try to read any remaining data
-                    try:
-                        data = await stream.read()
-                    except Exception:
-                        pass
-
-                # If we got no data, log a warning and return
-                if not data:
-                    logger.warning(
-                        "No data received in raw format from peer %s", peer_id
-                    )
-                    return
+            # Use the utility function to read the protobuf message
+            data = await read_length_prefixed_protobuf(stream, use_varint_format)

             identify_msg = Identify()
             identify_msg.ParseFromString(data)
@@ -119,6 +78,11 @@ def identify_push_handler_for(
             )
             logger.debug("Successfully processed identify/push from peer %s", peer_id)

+            # Send acknowledgment to indicate successful processing
+            # This ensures the sender knows the message was received before closing
+            await stream.write(b"OK")
+
         except StreamClosed:
             logger.debug(
                 "Stream closed while processing identify/push from %s", peer_id
@@ -127,7 +91,10 @@ def identify_push_handler_for(
             logger.error("Error processing identify/push from %s: %s", peer_id, e)
         finally:
             # Close the stream after processing
-            await stream.close()
+            try:
+                await stream.close()
+            except Exception:
+                pass  # Ignore errors when closing

     return handle_identify_push
@@ -226,7 +193,20 @@ async def push_identify_to_peer(
         # Send raw protobuf message
         await stream.write(response)

-    # Close the stream
+    # Wait for acknowledgment from the receiver with timeout
+    # This ensures the message was processed before closing
+    try:
+        with trio.move_on_after(1.0):  # 1 second timeout
+            ack = await stream.read(2)  # Read "OK" acknowledgment
+            if ack != b"OK":
+                logger.warning(
+                    "Unexpected acknowledgment from peer %s: %s", peer_id, ack
+                )
+    except Exception as e:
+        logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
+        # Continue anyway, as the message might have been processed
+
+    # Close the stream after acknowledgment (or timeout)
     await stream.close()

     logger.debug("Successfully pushed identify to peer %s", peer_id)

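The acknowledgment added here gives identify/push a simple two-byte handshake: the receiver writes b"OK" after a successful parse, and the sender waits up to one second for it before closing. A condensed sketch of the sender side of that handshake (payload stands for the already-serialized identify message; the function name is illustrative):

    import trio


    async def write_and_wait_for_ack(stream, payload: bytes) -> bool:
        # Send the message, then give the receiver a bounded window to
        # confirm it was processed before we tear the stream down.
        await stream.write(payload)
        ack = b""
        with trio.move_on_after(1.0):  # same 1-second budget as the diff
            ack = await stream.read(2)
        await stream.close()
        return ack == b"OK"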
View File

@@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
     MuxedStreamReset,
 )

+# Configure logger for this module
+logger = logging.getLogger("libp2p.stream_muxer.yamux")
+
 PROTOCOL_ID = "/yamux/1.0.0"
 TYPE_DATA = 0x0
 TYPE_WINDOW_UPDATE = 0x1
@@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
         # Flow control: Check if we have enough send window
         total_len = len(data)
         sent = 0
-        logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
+        logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
         while sent < total_len:
             # Wait for available window with timeout
             timeout = False
             async with self.window_lock:
                 if self.send_window == 0:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.stream_id}: Window is zero, waiting for update"
                     )
                     # Release lock and wait with timeout
@@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
         """
         if increment <= 0:
             # If increment is zero or negative, skip sending update
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Skipping window update"
                 f"(increment={increment})"
             )
             return
-        logging.debug(
+        logger.debug(
             f"Stream {self.stream_id}: Sending window update with increment={increment}"
         )
@@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
             # If the stream is closed for receiving and the buffer is empty, raise EOF
             if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
-                logging.debug(
+                logger.debug(
                     f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
                 )
                 raise MuxedStreamEOF("Stream is closed for receiving")
@@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
             # If buffer is not available, check if stream is closed
             if buffer is None:
-                logging.debug(f"Stream {self.stream_id}: No buffer available")
+                logger.debug(f"Stream {self.stream_id}: No buffer available")
                 raise MuxedStreamEOF("Stream buffer closed")

             # If we have data in buffer, process it
@@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
                 # Send window update for the chunk we just read
                 async with self.window_lock:
                     self.recv_window += len(chunk)
-                    logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
+                    logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
                     await self.send_window_update(len(chunk), skip_lock=True)

             # If stream is closed (FIN received) and buffer is empty, break
             if self.recv_closed and len(buffer) == 0:
-                logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
+                logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
                 break

             # If stream was reset, raise reset error
             if self.reset_received:
-                logging.debug(f"Stream {self.stream_id}: Stream was reset")
+                logger.debug(f"Stream {self.stream_id}: Stream was reset")
                 raise MuxedStreamReset("Stream was reset")

             # Wait for more data or stream closure
-            logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
+            logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
             await self.conn.stream_events[self.stream_id].wait()
             self.conn.stream_events[self.stream_id] = trio.Event()

         # After loop exit, first check if we have data to return
         if data:
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
             )
             return data

         # No data accumulated, now check why we exited the loop
         if self.conn.event_shutting_down.is_set():
-            logging.debug(f"Stream {self.stream_id}: Connection shutting down")
+            logger.debug(f"Stream {self.stream_id}: Connection shutting down")
             raise MuxedStreamEOF("Connection shut down")

         # Return empty data
@@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
             data = await self.conn.read_stream(self.stream_id, n)
             async with self.window_lock:
                 self.recv_window += len(data)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.stream_id}: Sending window update after read, "
                     f"increment={len(data)}"
                 )
@@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
     async def close(self) -> None:
         if not self.send_closed:
-            logging.debug(f"Half-closing stream {self.stream_id} (local end)")
+            logger.debug(f"Half-closing stream {self.stream_id} (local end)")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
             )
@@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
     async def reset(self) -> None:
         if not self.closed:
-            logging.debug(f"Resetting stream {self.stream_id}")
+            logger.debug(f"Resetting stream {self.stream_id}")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
             )
@@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
         self._nursery: Nursery | None = None

     async def start(self) -> None:
-        logging.debug(f"Starting Yamux for {self.peer_id}")
+        logger.debug(f"Starting Yamux for {self.peer_id}")
         if self.event_started.is_set():
             return
         async with trio.open_nursery() as nursery:
@@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
         return self.is_initiator_value

     async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
-        logging.debug(f"Closing Yamux connection with code {error_code}")
+        logger.debug(f"Closing Yamux connection with code {error_code}")
         async with self.streams_lock:
             if not self.event_shutting_down.is_set():
                 try:
@@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
                     )
                     await self.secured_conn.write(header)
                 except Exception as e:
-                    logging.debug(f"Failed to send GO_AWAY: {e}")
+                    logger.debug(f"Failed to send GO_AWAY: {e}")
             self.event_shutting_down.set()
             for stream in self.streams.values():
                 stream.closed = True
@@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
             self.stream_events.clear()
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as e:
-            logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
+            logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
         self.event_closed.set()
         if self.on_close:
-            logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
             if inspect.iscoroutinefunction(self.on_close):
                 if self.on_close is not None:
                     await self.on_close()
@@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
                 header = struct.pack(
                     YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
                 )
-                logging.debug(f"Sending SYN header for stream {stream_id}")
+                logger.debug(f"Sending SYN header for stream {stream_id}")
                 await self.secured_conn.write(header)
                 return stream
             except Exception as e:
@@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
                 raise e

     async def accept_stream(self) -> IMuxedStream:
-        logging.debug("Waiting for new stream")
+        logger.debug("Waiting for new stream")
         try:
             stream = await self.new_stream_receive_channel.receive()
-            logging.debug(f"Received stream {stream.stream_id}")
+            logger.debug(f"Received stream {stream.stream_id}")
             return stream
         except trio.EndOfChannel:
             raise MuxedStreamError("No new streams available")

     async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
-        logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
+        logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
         if n is None:
             n = -1
         while True:
             async with self.streams_lock:
                 if stream_id not in self.streams:
-                    logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
+                    logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
                     raise MuxedStreamEOF("Stream closed")
                 if self.event_shutting_down.is_set():
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: connection shutting down"
                     )
                     raise MuxedStreamEOF("Connection shut down")
                 stream = self.streams[stream_id]
                 buffer = self.stream_buffers.get(stream_id)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.peer_id}:{stream_id}: "
                     f"closed={stream.closed}, "
                     f"recv_closed={stream.recv_closed}, "
@@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
                     f"buffer_len={len(buffer) if buffer else 0}"
                 )
                 if buffer is None:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"Buffer gone, assuming closed"
                     )
@@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
                     return data
                 # If reset received and buffer is empty, raise reset
                 if stream.reset_received:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"reset_received=True, raising MuxedStreamReset"
                     )
@@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes"
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
                     return data
                 # Check if stream is closed
                 if stream.closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"closed=True, raising MuxedStreamReset"
                     )
                     raise MuxedStreamReset("Stream is reset or closed")
                 # Check if recv_closed and buffer empty
                 if stream.recv_closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}:"
                         f"recv_closed=True, buffer empty, raising EOF"
                     )
                     raise MuxedStreamEOF("Stream is closed for receiving")

             # Wait for data if stream is still open
-            logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
+            logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
             try:
                 await self.stream_events[stream_id].wait()
                 self.stream_events[stream_id] = trio.Event()
@@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
             try:
                 header = await self.secured_conn.read(HEADER_SIZE)
                 if not header or len(header) < HEADER_SIZE:
-                    logging.debug(
+                    logger.debug(
                         f"Connection closed orincomplete header for peer {self.peer_id}"
                     )
                     self.event_shutting_down.set()
@@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
                 version, typ, flags, stream_id, length = struct.unpack(
                     YAMUX_HEADER_FORMAT, header
                 )
-                logging.debug(
+                logger.debug(
                     f"Received header for peer {self.peer_id}:"
                     f"type={typ}, flags={flags}, stream_id={stream_id},"
                     f"length={length}"
@@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
                         0,
                     )
                     await self.secured_conn.write(ack_header)
-                    logging.debug(
+                    logger.debug(
                         f"Sending stream {stream_id}"
                         f"to channel for peer {self.peer_id}"
                     )
@@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_RST:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Resetting stream {stream_id} for peer {self.peer_id}"
                             )
                             self.streams[stream_id].closed = True
@@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_ACK:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Received ACK for stream"
                                 f"{stream_id} for peer {self.peer_id}"
                             )
                 elif typ == TYPE_GO_AWAY:
                     error_code = length
                     if error_code == GO_AWAY_NORMAL:
-                        logging.debug(
+                        logger.debug(
                             f"Received GO_AWAY for peer"
                             f"{self.peer_id}: Normal termination"
                         )
                     elif error_code == GO_AWAY_PROTOCOL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
                         )
                     elif error_code == GO_AWAY_INTERNAL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}: Internal error"
                         )
                     else:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}"
                             f"with unknown error code: {error_code}"
                         )
@@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
                     break
                 elif typ == TYPE_PING:
                     if flags & FLAG_SYN:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping request with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
                         )
                         await self.secured_conn.write(ping_header)
                     elif flags & FLAG_ACK:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping response with value"
                             f"{length} for peer {self.peer_id}"
                         )
@@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
                             self.stream_buffers[stream_id].extend(data)
                             self.stream_events[stream_id].set()
                             if flags & FLAG_FIN:
-                                logging.debug(
+                                logger.debug(
                                     f"Received FIN for stream {self.peer_id}:"
                                     f"{stream_id}, marking recv_closed"
                                 )
@@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
                                 if self.streams[stream_id].send_closed:
                                     self.streams[stream_id].closed = True
                     except Exception as e:
-                        logging.error(f"Error reading data for stream {stream_id}: {e}")
+                        logger.error(f"Error reading data for stream {stream_id}: {e}")
                         # Mark stream as closed on read error
                         async with self.streams_lock:
                             if stream_id in self.streams:
@@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
                         if stream_id in self.streams:
                             stream = self.streams[stream_id]
                             async with stream.window_lock:
-                                logging.debug(
+                                logger.debug(
                                     f"Received window update for stream"
                                     f"{self.peer_id}:{stream_id},"
                                     f" increment: {increment}"
@@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
                     and details.get("requested_count") == 2
                     and details.get("received_count") == 0
                 ):
-                    logging.info(
+                    logger.info(
                         f"Stream closed cleanly for peer {self.peer_id}"
                         + f" (IncompleteReadError: {details})"
                     )
@@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
                     await self._cleanup_on_error()
                     break
                 else:
-                    logging.error(
+                    logger.error(
                         f"Error in handle_incoming for peer {self.peer_id}: "
                         + f"{type(e).__name__}: {str(e)}"
                     )
                 else:
-                    logging.error(
-                        f"Error in handle_incoming for peer {self.peer_id}: "
-                        + f"{type(e).__name__}: {str(e)}"
-                    )
+                    # Handle RawConnError with more nuance
+                    if isinstance(e, RawConnError):
+                        error_msg = str(e)
+                        # If RawConnError is empty, it's likely normal cleanup
+                        if not error_msg.strip():
+                            logger.info(
+                                f"RawConnError (empty) during cleanup for peer "
+                                f"{self.peer_id} (normal connection shutdown)"
+                            )
+                        else:
+                            # Log non-empty RawConnError as warning
+                            logger.warning(
+                                f"RawConnError during connection handling for peer "
+                                f"{self.peer_id}: {error_msg}"
+                            )
+                    else:
+                        # Log all other errors normally
+                        logger.error(
+                            f"Error in handle_incoming for peer {self.peer_id}: "
+                            + f"{type(e).__name__}: {str(e)}"
+                        )
                 # Don't crash the whole connection for temporary errors
                 if self.event_shutting_down.is_set() or isinstance(
                     e, (RawConnError, OSError)
@@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
         # Close the secured connection
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as close_error:
-            logging.error(
+            logger.error(
                 f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
             )
@@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
         # Call on_close callback if provided
         if self.on_close:
-            logging.debug(f"Calling on_close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close for peer {self.peer_id}")
             try:
                 if inspect.iscoroutinefunction(self.on_close):
                     await self.on_close()
                 else:
                     self.on_close()
             except Exception as callback_error:
-                logging.error(f"Error in on_close callback: {callback_error}")
+                logger.error(f"Error in on_close callback: {callback_error}")

         # Cancel nursery tasks
         if self._nursery:

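Because these call sites now go through a named module logger instead of the root logging module, yamux verbosity can be tuned in isolation; a minimal sketch:

    import logging

    # Send yamux frame and window-update chatter to the console without
    # turning on debug output for every other libp2p module.
    logging.basicConfig(format="%(name)s %(levelname)s %(message)s")
    logging.getLogger("libp2p.stream_muxer.yamux").setLevel(logging.DEBUG)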
View File

@@ -9,6 +9,7 @@ from libp2p.utils.varint import (
     read_varint_prefixed_bytes,
     decode_varint_from_bytes,
     decode_varint_with_size,
+    read_length_prefixed_protobuf,
 )
 from libp2p.utils.version import (
     get_agent_version,
@@ -24,4 +25,5 @@ __all__ = [
     "read_varint_prefixed_bytes",
     "decode_varint_from_bytes",
     "decode_varint_with_size",
+    "read_length_prefixed_protobuf",
 ]

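With the re-export in place, callers can presumably import the helper from the package root rather than reaching into the submodule:

    # Same object as libp2p.utils.varint.read_length_prefixed_protobuf
    from libp2p.utils import read_length_prefixed_protobuf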
View File

@@ -1,7 +1,9 @@
 import itertools
 import logging
 import math
+from typing import BinaryIO

+from libp2p.abc import INetStream
 from libp2p.exceptions import (
     ParseError,
 )
@@ -25,42 +27,41 @@ HIGH_MASK = 2**7
 SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7


-def encode_uvarint(number: int) -> bytes:
-    """Pack `number` into varint bytes."""
-    buf = b""
-    while True:
-        towrite = number & 0x7F
-        number >>= 7
-        if number:
-            buf += bytes((towrite | 0x80,))
-        else:
-            buf += bytes((towrite,))
-            break
-    return buf
+def encode_uvarint(value: int) -> bytes:
+    """Encode an unsigned integer as a varint."""
+    if value < 0:
+        raise ValueError("Cannot encode negative value as uvarint")
+
+    result = bytearray()
+    while value >= 0x80:
+        result.append((value & 0x7F) | 0x80)
+        value >>= 7
+    result.append(value & 0x7F)
+    return bytes(result)
+
+
+def decode_uvarint(data: bytes) -> int:
+    """Decode a varint from bytes."""
+    if not data:
+        raise ParseError("Unexpected end of data")
+
+    result = 0
+    shift = 0
+    for byte in data:
+        result |= (byte & 0x7F) << shift
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+        if shift >= 64:
+            raise ValueError("Varint too long")
+
+    return result


 def decode_varint_from_bytes(data: bytes) -> int:
-    """
-    Decode a varint from bytes and return the value.
-
-    This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
-    """
-    res = 0
-    for shift in itertools.count(0, 7):
-        if shift > SHIFT_64_BIT_MAX:
-            raise ParseError("Integer is too large...")
-
-        if not data:
-            raise ParseError("Unexpected end of data")
-
-        value = data[0]
-        data = data[1:]
-
-        res += (value & LOW_MASK) << shift
-
-        if not value & HIGH_MASK:
-            break
-    return res
+    """Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
+    return decode_uvarint(data)


 async def decode_uvarint_from_stream(reader: Reader) -> int:
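A quick round-trip check of the rewritten codec (300 encodes to 0xAC 0x02 in LEB128-style unsigned varints), assuming the functions as defined in the hunk above:

    from libp2p.utils.varint import (
        decode_uvarint,
        decode_varint_from_bytes,
        encode_uvarint,
    )

    # 300 = 0b10_0101100: low 7 bits with the continuation bit set, then the rest
    assert encode_uvarint(300) == bytes([0xAC, 0x02])
    assert decode_uvarint(bytes([0xAC, 0x02])) == 300
    # the backward-compatible alias agrees with the new decoder
    assert decode_varint_from_bytes(encode_uvarint(2**32)) == 2**32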
@@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:

 def decode_varint_with_size(data: bytes) -> tuple[int, int]:
     """
-    Decode a varint from bytes and return (value, bytes_consumed).
-    Returns (0, 0) if the data doesn't start with a valid varint.
+    Decode a varint from bytes and return both the value and the number of bytes
+    consumed.
+
+    Returns:
+        Tuple[int, int]: (value, bytes_consumed)
     """
-    try:
-        # Calculate how many bytes the varint consumes
-        varint_size = 0
-        for i, byte in enumerate(data):
-            varint_size += 1
-            if (byte & 0x80) == 0:
-                break
-
-        if varint_size == 0:
-            return 0, 0
-
-        # Extract just the varint bytes
-        varint_bytes = data[:varint_size]
-
-        # Decode the varint
-        value = decode_varint_from_bytes(varint_bytes)
-        return value, varint_size
-    except Exception:
-        return 0, 0
+    result = 0
+    shift = 0
+    bytes_consumed = 0
+
+    for byte in data:
+        result |= (byte & 0x7F) << shift
+        bytes_consumed += 1
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+        if shift >= 64:
+            raise ValueError("Varint too long")
+
+    return result, bytes_consumed


-def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
-    varint_len = encode_uvarint(len(msg_bytes))
-    return varint_len + msg_bytes
+def encode_varint_prefixed(data: bytes) -> bytes:
+    """Encode data with a varint length prefix."""
+    length_bytes = encode_uvarint(len(data))
+    return length_bytes + data


 async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}' f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
) )
return msg_bytes[:-1] return msg_bytes[:-1]
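As a concrete check of the rewritten helpers: 300 is 0b100101100, which splits into two 7-bit groups, so it should encode to the two bytes 0xAC 0x02. A short round-trip sketch against the functions above (a sketch, assuming this branch of py-libp2p is importable):

from libp2p.utils.varint import (
    decode_varint_with_size,
    encode_uvarint,
)

# Low 7 bits 0x2C plus continuation bit -> 0xAC; remaining bits -> 0x02.
encoded = encode_uvarint(300)
assert encoded == b"\xac\x02"

# The decoder reports how many bytes the varint occupied, so trailing
# payload bytes are left untouched.
value, consumed = decode_varint_with_size(encoded + b"payload")
assert (value, consumed) == (300, 2)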
def read_varint_prefixed_bytes_sync(
stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
"""
Read varint-prefixed bytes from a stream.
Args:
stream: A stream-like object with a read() method
max_length: Maximum allowed data length to prevent memory exhaustion
Returns:
bytes: The data without the length prefix
Raises:
ValueError: If the length prefix is invalid or too large
EOFError: If the stream ends unexpectedly
"""
# Read the varint length prefix
length_bytes = b""
while True:
byte_data = stream.read(1)
if not byte_data:
raise EOFError("Stream ended while reading varint length prefix")
length_bytes += byte_data
if byte_data[0] & 0x80 == 0:
break
# Decode the length
length = decode_uvarint(length_bytes)
if length > max_length:
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
# Read the data
data = stream.read(length)
if len(data) != length:
raise EOFError(f"Expected {length} bytes, got {len(data)}")
return data
async def read_length_prefixed_protobuf(
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
"""Read a protobuf message from a stream, handling both formats."""
if use_varint_format:
# Read length-prefixed protobuf message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("No length prefix received")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
if msg_length > max_length:
raise Exception(
f"Message length {msg_length} exceeds maximum allowed {max_length}"
)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length}, got {len(data)}"
)
return data
else:
# Read raw protobuf message from the stream
# For raw format, read all available data in one go
data = await stream.read()
# If we got no data, raise an exception
if not data:
raise Exception("No data received in raw format")
if len(data) > max_length:
raise Exception(
f"Message length {len(data)} exceeds maximum allowed {max_length}"
)
return data
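The new synchronous reader pairs naturally with `encode_varint_prefixed`, and `io.BytesIO` satisfies the `BinaryIO` parameter, so the framing can be exercised without a network stream. A minimal round-trip sketch under those assumptions:

import io

from libp2p.utils.varint import (
    encode_varint_prefixed,
    read_varint_prefixed_bytes_sync,
)

payload = b"identify-push protobuf bytes"
framed = encode_varint_prefixed(payload)

# The reader strips the varint length prefix and returns only the payload.
assert read_varint_prefixed_bytes_sync(io.BytesIO(framed)) == payload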

View File

@ -0,0 +1 @@
Fixed incorrect handling of the raw protobuf format in the identify/push protocol. The identify/push example now handles both raw and length-prefixed (varint) message formats, emits clearer error messages, and displays connection status with peer IDs. Replaced the mock-based tests with real-network integration tests covering both formats.
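To make the difference between the two formats concrete: the varint format prepends a length prefix to the serialized protobuf, while the raw format writes the serialized bytes alone and relies on end-of-stream framing. A hedged sketch (the byte string stands in for a real Identify message):

from libp2p.utils.varint import encode_uvarint

serialized = b"\x0a\x04test"  # stand-in for serialized identify protobuf bytes

# Length-prefixed (varint) framing: the prefix is the payload size, here 6.
varint_framed = encode_uvarint(len(serialized)) + serialized
assert varint_framed[0] == len(serialized)

# Raw framing: no prefix; the reader consumes the stream until EOF.
raw_framed = serialized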

View File

@ -0,0 +1 @@
Refactored Yamux RawConnError logging: improved error handling and debug logging.

View File

@ -0,0 +1,552 @@
import logging
import pytest
import trio
from libp2p.custom_types import TProtocol
from libp2p.identity.identify_push.identify_push import (
ID_PUSH,
identify_push_handler_for,
push_identify_to_peer,
push_identify_to_peers,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-push-integration-test")
@pytest.mark.trio
async def test_identify_push_protocol_varint_format_integration(security_protocol):
"""Test identify/push protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to host_b so it has something to push
async def dummy_handler(stream):
pass
host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
# The protocols should include the dummy protocols we added
assert len(protocols) >= 2 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_protocol_raw_format_integration(security_protocol):
"""Test identify/push protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_default_format_behavior(security_protocol):
"""Test identify/push protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Use default identify/push handler (should use varint format)
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_varint_to_raw(
security_protocol,
):
"""Test varint pusher with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=False)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses raw format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
        # Host B pushes with varint format; the mismatch with host A's raw
        # handler may cause the push to fail gracefully
        success = await push_identify_to_peer(
            host_b, host_a.get_id(), use_varint_format=True
        )
        # Format detection on the receiving side can be more robust than a
        # strict mismatch suggests, so assert only that the operation
        # completes and returns a boolean result
        assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_raw_to_varint(
security_protocol,
):
"""Test raw pusher with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=True)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses varint format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
        # Host B pushes with raw format; the mismatch with host A's varint
        # handler may cause the push to fail gracefully
        success = await push_identify_to_peer(
            host_b, host_a.get_id(), use_varint_format=False
        )
        # Format detection on the receiving side can be more robust than a
        # strict mismatch suggests, so assert only that the operation
        # completes and returns a boolean result
        assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_multiple_peers_integration(security_protocol):
"""Test identify/push protocol with multiple peers."""
# Create two hosts using the factory
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a third host following the pattern from test_identify_push.py
import multiaddr
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.peerinfo import info_from_p2p_addr
# Create a new key pair for host_c
key_pair_c = create_new_key_pair()
host_c = new_host(key_pair=key_pair_c)
# Set up identify/push handlers on all hosts
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
# Start listening on a random port using the run context manager
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
async with host_c.run([listen_addr]):
# Connect host_c to host_a and host_b using the correct pattern
await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0]))
await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0]))
# Push identify information from host_a to all connected peers
await push_identify_to_peers(host_a)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_b's peerstore has been updated with host_a's information
peerstore_b = host_b.get_peerstore()
peer_id_a = host_a.get_id()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_b.peer_ids()
# Check that host_c's peerstore has been updated with host_a's information
peerstore_c = host_c.get_peerstore()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_c.peer_ids()
# Test for push_identify to only connected peers and not all peers
# Disconnect a from c.
await host_c.disconnect(host_a.get_id())
await push_identify_to_peers(host_c)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_a's peerstore has not been updated with host_c's info
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
# Check that host_b's peerstore has been updated with host_c's info
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
@pytest.mark.trio
async def test_identify_push_large_message_handling(security_protocol):
"""Test identify/push protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
# Also add some protocols to host_a to ensure it has protocols to push
for i in range(5):
host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=True
)
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated with all protocols
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 10 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_peerstore_update_completeness(security_protocol):
"""Test that identify/push updates all relevant peerstore information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that public key was added
pubkey = peerstore_a.pubkey(peer_id_b)
assert pubkey is not None
assert pubkey.serialize() == host_b.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_push_concurrent_requests(security_protocol):
"""Test identify/push protocol handles concurrent requests properly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Make multiple concurrent push requests
results = []
async def push_identify():
result = await push_identify_to_peer(host_b, host_a.get_id())
results.append(result)
# Run multiple concurrent pushes using nursery
async with trio.open_nursery() as nursery:
for _ in range(3):
nursery.start_soon(push_identify)
# All should succeed
assert len(results) == 3
assert all(results)
# Wait a bit for the pushes to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_stream_handling(security_protocol):
"""Test identify/push protocol properly handles stream lifecycle."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_error_handling(security_protocol):
"""Test identify/push protocol handles errors gracefully."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a handler that raises an exception but catches it to prevent test
# failure
async def error_handler(stream):
try:
await stream.close()
raise Exception("Test error")
except Exception:
# Catch the exception to prevent it from propagating up
pass
host_a.set_stream_handler(ID_PUSH, error_handler)
# Push should complete (message sent) but handler should fail gracefully
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success # The push operation itself succeeds (message sent)
# Wait a bit for the handler to process
await trio.sleep(0.1)
# Verify that the error was handled gracefully (no test failure)
# The handler caught the exception and didn't propagate it
@pytest.mark.trio
async def test_identify_push_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent peerstore updates in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Test varint format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after varint push
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols_varint = peerstore_a.get_protocols(peer_id_b)
addrs_varint = peerstore_a.addrs(peer_id_b)
# Clear peerstore for next test
peerstore_a.clear_addrs(peer_id_b)
peerstore_a.clear_protocol_data(peer_id_b)
# Test raw format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after raw push
protocols_raw = peerstore_a.get_protocols(peer_id_b)
addrs_raw = peerstore_a.addrs(peer_id_b)
# Both should produce equivalent peerstore updates
# Check that both formats successfully updated protocols
assert protocols_varint is not None
assert protocols_raw is not None
assert len(protocols_varint) > 0
assert len(protocols_raw) > 0
# Check that both formats successfully updated addresses
assert addrs_varint is not None
assert addrs_raw is not None
assert len(addrs_varint) > 0
assert len(addrs_raw) > 0
# Both should contain the same essential information
# (exact address lists might differ due to format-specific handling)
assert set(protocols_varint) == set(protocols_raw)
@pytest.mark.trio
async def test_identify_push_with_observed_address(security_protocol):
"""Test identify/push protocol includes observed address information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Get host_b's address as observed by host_a
from multiaddr import Multiaddr
host_b_addr = host_b.get_addrs()[0]
observed_multiaddr = Multiaddr(str(host_b_addr))
# Push identify information with observed address
await push_identify_to_peer(
host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr
)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# The observed address should be among the stored addresses
addr_strings = [str(addr) for addr in addrs]
assert str(observed_multiaddr) in addr_strings