mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into py-multiaddr
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import base64
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -112,7 +113,12 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
# Replace the handler with our custom one
|
||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
|
||||
|
||||
await trio.sleep_forever()
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer)
|
||||
@ -147,38 +153,13 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
try:
|
||||
print("Starting identify protocol...")
|
||||
|
||||
# Read the response properly based on the format
|
||||
if use_varint_format:
|
||||
# For length-prefixed format, read varint length first
|
||||
from libp2p.utils.varint import decode_varint_from_bytes
|
||||
# Read the response using the utility function
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
# Read varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
raise Exception("Stream closed while reading varint length")
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
print(f"Expected message length: {msg_length} bytes")
|
||||
|
||||
# Read the protobuf message
|
||||
response = await stream.read(msg_length)
|
||||
if len(response) != msg_length:
|
||||
raise Exception(
|
||||
f"Incomplete message: expected {msg_length} bytes, "
|
||||
f"got {len(response)}"
|
||||
)
|
||||
|
||||
# Combine length prefix and message
|
||||
full_response = length_bytes + response
|
||||
else:
|
||||
# For raw format, read all available data
|
||||
response = await stream.read(8192)
|
||||
full_response = response
|
||||
response = await read_length_prefixed_protobuf(
|
||||
stream, use_varint_format
|
||||
)
|
||||
full_response = response
|
||||
|
||||
await stream.close()
|
||||
|
||||
@ -254,6 +235,7 @@ def main() -> None:
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
@ -261,9 +243,19 @@ def main() -> None:
|
||||
use_varint_format = not args.raw_format
|
||||
|
||||
try:
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
if args.destination:
|
||||
# Run in dialer mode
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
else:
|
||||
# Run in listener mode
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
print("\n👋 Goodbye!")
|
||||
logger.info("Application interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {str(e)}")
|
||||
logger.error("Error: %s", str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -11,23 +11,26 @@ This example shows how to:
|
||||
|
||||
import logging
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.identity.identify import (
|
||||
identify_handler_for,
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.identity.identify_push import (
|
||||
ID_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_custom_identify_handler(host, host_name: str):
|
||||
"""Create a custom identify handler that displays received information."""
|
||||
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
|
||||
|
||||
# Get the standard identify response using the existing function
|
||||
from libp2p.identity.identify.identify import (
|
||||
_mk_identify_protobuf,
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
|
||||
# Get observed address
|
||||
observed_multiaddr = None
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build the identify protobuf
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response_data = identify_msg.SerializeToString()
|
||||
|
||||
print(f" 📋 {host_name} identify information:")
|
||||
if identify_msg.HasField("protocol_version"):
|
||||
print(f" Protocol Version: {identify_msg.protocol_version}")
|
||||
if identify_msg.HasField("agent_version"):
|
||||
print(f" Agent Version: {identify_msg.agent_version}")
|
||||
if identify_msg.HasField("public_key"):
|
||||
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
|
||||
if identify_msg.listen_addrs:
|
||||
print(" Listen Addresses:")
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
addr = multiaddr.Multiaddr(addr_bytes)
|
||||
print(f" - {addr}")
|
||||
if identify_msg.protocols:
|
||||
print(" Supported Protocols:")
|
||||
for protocol in identify_msg.protocols:
|
||||
print(f" - {protocol}")
|
||||
|
||||
# Send the response
|
||||
await stream.write(response_data)
|
||||
await stream.close()
|
||||
|
||||
return handle_identify
|
||||
|
||||
|
||||
def create_custom_identify_push_handler(host, host_name: str):
|
||||
"""Create a custom identify/push handler that displays received information."""
|
||||
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")
|
||||
|
||||
try:
|
||||
# Read the identify message using the utility function
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
|
||||
|
||||
# Parse the identify message
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
print(" 📋 Received identify information:")
|
||||
if identify_msg.HasField("protocol_version"):
|
||||
print(f" Protocol Version: {identify_msg.protocol_version}")
|
||||
if identify_msg.HasField("agent_version"):
|
||||
print(f" Agent Version: {identify_msg.agent_version}")
|
||||
if identify_msg.HasField("public_key"):
|
||||
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
|
||||
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
|
||||
observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
|
||||
print(f" Observed Address: {observed_addr}")
|
||||
if identify_msg.listen_addrs:
|
||||
print(" Listen Addresses:")
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
addr = multiaddr.Multiaddr(addr_bytes)
|
||||
print(f" - {addr}")
|
||||
if identify_msg.protocols:
|
||||
print(" Supported Protocols:")
|
||||
for protocol in identify_msg.protocols:
|
||||
print(f" - {protocol}")
|
||||
|
||||
# Update the peerstore with the new information
|
||||
from libp2p.identity.identify_push.identify_push import (
|
||||
_update_peerstore_from_identify,
|
||||
)
|
||||
|
||||
await _update_peerstore_from_identify(
|
||||
host.get_peerstore(), peer_id, identify_msg
|
||||
)
|
||||
|
||||
print(f" ✅ {host_name} updated peerstore with new information")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error processing identify/push: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def display_peerstore_info(host, host_name: str, peer_id, description: str):
|
||||
"""Display peerstore information for a specific peer."""
|
||||
peerstore = host.get_peerstore()
|
||||
|
||||
try:
|
||||
addrs = peerstore.addrs(peer_id)
|
||||
except Exception:
|
||||
addrs = []
|
||||
|
||||
try:
|
||||
protocols = peerstore.get_protocols(peer_id)
|
||||
except Exception:
|
||||
protocols = []
|
||||
|
||||
print(f"\n📚 {host_name} peerstore for {description}:")
|
||||
print(f" Peer ID: {peer_id}")
|
||||
if addrs:
|
||||
print(" Addresses:")
|
||||
for addr in addrs:
|
||||
print(f" - {addr}")
|
||||
else:
|
||||
print(" Addresses: None")
|
||||
|
||||
if protocols:
|
||||
print(" Protocols:")
|
||||
for protocol in protocols:
|
||||
print(f" - {protocol}")
|
||||
else:
|
||||
print(" Protocols: None")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("\n==== Starting Identify-Push Example ====\n")
|
||||
print("\n==== Starting Enhanced Identify-Push Example ====\n")
|
||||
|
||||
# Create key pairs for the two hosts
|
||||
key_pair_1 = create_new_key_pair()
|
||||
@ -48,45 +188,49 @@ async def main() -> None:
|
||||
# Create the first host
|
||||
host_1 = new_host(key_pair=key_pair_1)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
|
||||
host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_1.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
|
||||
)
|
||||
host_1.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
|
||||
)
|
||||
|
||||
# Create the second host
|
||||
host_2 = new_host(key_pair=key_pair_2)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
|
||||
host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_2.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
|
||||
)
|
||||
host_2.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
|
||||
)
|
||||
|
||||
# Start listening on random ports using the run context manager
|
||||
import multiaddr
|
||||
|
||||
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
|
||||
async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
|
||||
# Get the addresses of both hosts
|
||||
addr_1 = host_1.get_addrs()[0]
|
||||
logger.info(f"Host 1 listening on {addr_1}")
|
||||
print(f"Host 1 listening on {addr_1}")
|
||||
print(f"Peer ID: {host_1.get_id().pretty()}")
|
||||
|
||||
addr_2 = host_2.get_addrs()[0]
|
||||
logger.info(f"Host 2 listening on {addr_2}")
|
||||
print(f"Host 2 listening on {addr_2}")
|
||||
print(f"Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\nConnecting Host 2 to Host 1...")
|
||||
print("🏠 Host Configuration:")
|
||||
print(f" Host 1: {addr_1}")
|
||||
print(f" Host 1 Peer ID: {host_1.get_id().pretty()}")
|
||||
print(f" Host 2: {addr_2}")
|
||||
print(f" Host 2 Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\n🔗 Connecting Host 2 to Host 1...")
|
||||
|
||||
# Connect host_2 to host_1
|
||||
peer_info = info_from_p2p_addr(addr_1)
|
||||
await host_2.connect(peer_info)
|
||||
logger.info("Host 2 connected to Host 1")
|
||||
print("Host 2 successfully connected to Host 1")
|
||||
print("✅ Host 2 successfully connected to Host 1")
|
||||
|
||||
# Run the identify protocol from host_2 to host_1
|
||||
# (so Host 1 learns Host 2's address)
|
||||
print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
|
||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
||||
|
||||
stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
@ -94,64 +238,58 @@ async def main() -> None:
|
||||
await stream.close()
|
||||
|
||||
# Run the identify protocol from host_1 to host_2
|
||||
# (so Host 2 learns Host 1's address)
|
||||
print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
|
||||
stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
||||
# --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
|
||||
# Update Host 1's peerstore with Host 2's addresses
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(response)
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
maddr = multiaddr.Multiaddr(addr_bytes)
|
||||
# TTL can be any positive int
|
||||
peerstore_1.add_addr(
|
||||
peer_id_2,
|
||||
maddr,
|
||||
ttl=3600,
|
||||
)
|
||||
# --- END NEW CODE ---
|
||||
peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)
|
||||
|
||||
# Now Host 1's peerstore should have Host 2's address
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
|
||||
logger.info(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
)
|
||||
print(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
# Display peerstore information before push
|
||||
await display_peerstore_info(
|
||||
host_1, "Host 1", peer_id_2, "Host 2 (before push)"
|
||||
)
|
||||
|
||||
# Push identify information from host_1 to host_2
|
||||
logger.info("Host 1 pushing identify information to Host 2")
|
||||
print("\nHost 1 pushing identify information to Host 2...")
|
||||
print("\n📤 Host 1 pushing identify information to Host 2...")
|
||||
|
||||
try:
|
||||
# Call push_identify_to_peer which now returns a boolean
|
||||
success = await push_identify_to_peer(host_1, host_2.get_id())
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
else:
|
||||
logger.warning("Identify push didn't complete successfully")
|
||||
print("\nWarning: Identify push didn't complete successfully")
|
||||
print("⚠️ Identify push didn't complete successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
print(f"❌ Error during identify push: {str(e)}")
|
||||
|
||||
# Add this at the end of your async with block:
|
||||
await trio.sleep(0.5) # Give background tasks time to finish
|
||||
# Give a moment for the identify/push processing to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Display peerstore information after push
|
||||
await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
|
||||
await display_peerstore_info(
|
||||
host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
|
||||
)
|
||||
|
||||
# Give more time for background tasks to finish and connections to stabilize
|
||||
print("\n⏳ Waiting for background tasks to complete...")
|
||||
await trio.sleep(1.0)
|
||||
|
||||
# Gracefully close connections to prevent connection errors
|
||||
print("🔌 Closing connections...")
|
||||
await host_2.disconnect(host_1.get_id())
|
||||
await trio.sleep(0.2)
|
||||
|
||||
print("\n🎉 Example completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -41,6 +41,9 @@ from libp2p.identity.identify import (
|
||||
ID as ID_IDENTIFY,
|
||||
identify_handler_for,
|
||||
)
|
||||
from libp2p.identity.identify.identify import (
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
# Get remote address information
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
from libp2p.utils.varint import decode_varint_from_bytes
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
logger.info(
|
||||
"Connection from remote peer %s, address: %s, multiaddr: %s",
|
||||
peer_id,
|
||||
remote_address,
|
||||
observed_multiaddr,
|
||||
)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
# Add the peer ID to create a complete multiaddr
|
||||
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||
print(f" Remote address: {complete_multiaddr}")
|
||||
except Exception as e:
|
||||
logger.error("Error getting remote address: %s", e)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
try:
|
||||
# Use the utility function to read the protobuf message
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
data = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)
|
||||
|
||||
logger.info("Successfully processed identify/push from peer %s", peer_id)
|
||||
print(f"\nSuccessfully processed identify/push from peer {peer_id}")
|
||||
print(f"✅ Successfully processed identify/push from peer {peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
print(f"\nError processing identify/push from {peer_id}: {e}")
|
||||
error_msg = str(e)
|
||||
logger.error(
|
||||
"Error processing identify/push from %s: %s", peer_id, error_msg
|
||||
)
|
||||
print(f"\nError processing identify/push from {peer_id}: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"dialer is using raw protobuf format."
|
||||
)
|
||||
print("\nTo fix this, run the dialer with the --raw-format flag:")
|
||||
print(
|
||||
"identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the dialer")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format flag:"
|
||||
)
|
||||
print("identify-push-listener-dialer-demo -d <ADDRESS>")
|
||||
print("=" * 60)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
async def run_listener(
|
||||
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
|
||||
) -> None:
|
||||
"""Run a host in listener mode."""
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
)
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
)
|
||||
|
||||
# Start listening
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
try:
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
|
||||
print("\nRun dialer with command:")
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming connections... (Ctrl+C to exit)")
|
||||
print("\nRun dialer with command:")
|
||||
if raw_format_flag:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
|
||||
else:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
|
||||
|
||||
# Keep running until interrupted
|
||||
await trio.sleep_forever()
|
||||
# Keep running until interrupted
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Listener error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def run_dialer(
|
||||
@ -256,7 +293,9 @@ async def run_dialer(
|
||||
try:
|
||||
await host.connect(peer_info)
|
||||
logger.info("Successfully connected to listener!")
|
||||
print("Successfully connected to listener!")
|
||||
print("✅ Successfully connected to listener!")
|
||||
print(f" Connected to: {peer_info.peer_id}")
|
||||
print(f" Full address: {destination}")
|
||||
|
||||
# Push identify information to the listener
|
||||
logger.info("Pushing identify information to listener...")
|
||||
@ -270,7 +309,7 @@ async def run_dialer(
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully!")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
|
||||
logger.info("Example completed successfully!")
|
||||
print("\nExample completed successfully!")
|
||||
@ -281,17 +320,57 @@ async def run_dialer(
|
||||
logger.warning("Example completed with warnings.")
|
||||
print("Example completed with warnings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error during identify push: {error_msg}")
|
||||
print(f"\nError during identify push: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"listener is using raw protobuf format."
|
||||
)
|
||||
print(
|
||||
"\nTo fix this, run the dialer with the --raw-format flag:"
|
||||
)
|
||||
print(
|
||||
f"identify-push-listener-dialer-demo --raw-format -d "
|
||||
f"{destination}"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the listener")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format "
|
||||
"flag:"
|
||||
)
|
||||
print(f"identify-push-listener-dialer-demo -d {destination}")
|
||||
print("=" * 60)
|
||||
|
||||
logger.error("Example completed with errors.")
|
||||
print("Example completed with errors.")
|
||||
# Continue execution despite the push error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during dialer operation: {str(e)}")
|
||||
print(f"\nError during dialer operation: {str(e)}")
|
||||
raise
|
||||
error_msg = str(e)
|
||||
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
||||
print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
print("\n💡 Make sure the peer is running and the address is correct.")
|
||||
return
|
||||
else:
|
||||
logger.error(f"Error during dialer operation: {error_msg}")
|
||||
print(f"\nError during dialer operation: {error_msg}")
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -301,12 +380,21 @@ def main() -> None:
|
||||
Without arguments, it runs as a listener on random port.
|
||||
With -d parameter, it runs as a dialer on random port.
|
||||
|
||||
Port 0 (default) means the OS will automatically assign an available port.
|
||||
This prevents port conflicts when running multiple instances.
|
||||
|
||||
Use --raw-format to send raw protobuf messages (old format) instead of
|
||||
length-prefixed protobuf messages (new format, default).
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
default=0,
|
||||
type=int,
|
||||
help="source port number (0 = random available port)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -321,6 +409,7 @@ def main() -> None:
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
@ -333,12 +422,12 @@ def main() -> None:
|
||||
trio.run(run_dialer, args.port, args.destination, use_varint_format)
|
||||
else:
|
||||
# Run in listener mode with random available port if not specified
|
||||
trio.run(run_listener, args.port, use_varint_format)
|
||||
trio.run(run_listener, args.port, use_varint_format, args.raw_format)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
logger.info("Interrupted by user")
|
||||
print("\n👋 Goodbye!")
|
||||
logger.info("Application interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\nError: {str(e)}")
|
||||
print(f"\n❌ Error: {str(e)}")
|
||||
logger.error("Error: %s", str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
host: IHost, use_varint_format: bool = True
|
||||
) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
|
||||
@ -28,7 +28,7 @@ from libp2p.utils import (
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -66,49 +66,8 @@ def identify_push_handler_for(
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, we need to read all data before the stream is closed
|
||||
data = b""
|
||||
try:
|
||||
# Read all available data in a single operation
|
||||
data = await stream.read()
|
||||
except StreamClosed:
|
||||
# Try to read any remaining data
|
||||
try:
|
||||
data = await stream.read()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we got no data, log a warning and return
|
||||
if not data:
|
||||
logger.warning(
|
||||
"No data received in raw format from peer %s", peer_id
|
||||
)
|
||||
return
|
||||
# Use the utility function to read the protobuf message
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -119,6 +78,11 @@ def identify_push_handler_for(
|
||||
)
|
||||
|
||||
logger.debug("Successfully processed identify/push from peer %s", peer_id)
|
||||
|
||||
# Send acknowledgment to indicate successful processing
|
||||
# This ensures the sender knows the message was received before closing
|
||||
await stream.write(b"OK")
|
||||
|
||||
except StreamClosed:
|
||||
logger.debug(
|
||||
"Stream closed while processing identify/push from %s", peer_id
|
||||
@ -127,7 +91,10 @@ def identify_push_handler_for(
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing
|
||||
|
||||
return handle_identify_push
|
||||
|
||||
@ -226,7 +193,20 @@ async def push_identify_to_peer(
|
||||
# Send raw protobuf message
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
# Wait for acknowledgment from the receiver with timeout
|
||||
# This ensures the message was processed before closing
|
||||
try:
|
||||
with trio.move_on_after(1.0): # 1 second timeout
|
||||
ack = await stream.read(2) # Read "OK" acknowledgment
|
||||
if ack != b"OK":
|
||||
logger.warning(
|
||||
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
|
||||
# Continue anyway, as the message might have been processed
|
||||
|
||||
# Close the stream after acknowledgment (or timeout)
|
||||
await stream.close()
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
|
||||
@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamReset,
|
||||
)
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
|
||||
PROTOCOL_ID = "/yamux/1.0.0"
|
||||
TYPE_DATA = 0x0
|
||||
TYPE_WINDOW_UPDATE = 0x1
|
||||
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
|
||||
"""
|
||||
if increment <= 0:
|
||||
# If increment is zero or negative, skip sending update
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Skipping window update"
|
||||
f"(increment={increment})"
|
||||
)
|
||||
return
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Sending window update with increment={increment}"
|
||||
)
|
||||
|
||||
@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
# If the stream is closed for receiving and the buffer is empty, raise EOF
|
||||
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
|
||||
)
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
# If buffer is not available, check if stream is closed
|
||||
if buffer is None:
|
||||
logging.debug(f"Stream {self.stream_id}: No buffer available")
|
||||
logger.debug(f"Stream {self.stream_id}: No buffer available")
|
||||
raise MuxedStreamEOF("Stream buffer closed")
|
||||
|
||||
# If we have data in buffer, process it
|
||||
@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
|
||||
# Send window update for the chunk we just read
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(chunk)
|
||||
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
|
||||
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
|
||||
await self.send_window_update(len(chunk), skip_lock=True)
|
||||
|
||||
# If stream is closed (FIN received) and buffer is empty, break
|
||||
if self.recv_closed and len(buffer) == 0:
|
||||
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
|
||||
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
|
||||
break
|
||||
|
||||
# If stream was reset, raise reset error
|
||||
if self.reset_received:
|
||||
logging.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
logger.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
raise MuxedStreamReset("Stream was reset")
|
||||
|
||||
# Wait for more data or stream closure
|
||||
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
await self.conn.stream_events[self.stream_id].wait()
|
||||
self.conn.stream_events[self.stream_id] = trio.Event()
|
||||
|
||||
# After loop exit, first check if we have data to return
|
||||
if data:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
|
||||
)
|
||||
return data
|
||||
|
||||
# No data accumulated, now check why we exited the loop
|
||||
if self.conn.event_shutting_down.is_set():
|
||||
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
|
||||
# Return empty data
|
||||
@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
|
||||
data = await self.conn.read_stream(self.stream_id, n)
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(data)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Sending window update after read, "
|
||||
f"increment={len(data)}"
|
||||
)
|
||||
@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
async def reset(self) -> None:
|
||||
if not self.closed:
|
||||
logging.debug(f"Resetting stream {self.stream_id}")
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
|
||||
self._nursery: Nursery | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
logging.debug(f"Starting Yamux for {self.peer_id}")
|
||||
logger.debug(f"Starting Yamux for {self.peer_id}")
|
||||
if self.event_started.is_set():
|
||||
return
|
||||
async with trio.open_nursery() as nursery:
|
||||
@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
|
||||
return self.is_initiator_value
|
||||
|
||||
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
|
||||
logging.debug(f"Closing Yamux connection with code {error_code}")
|
||||
logger.debug(f"Closing Yamux connection with code {error_code}")
|
||||
async with self.streams_lock:
|
||||
if not self.event_shutting_down.is_set():
|
||||
try:
|
||||
@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
|
||||
)
|
||||
await self.secured_conn.write(header)
|
||||
except Exception as e:
|
||||
logging.debug(f"Failed to send GO_AWAY: {e}")
|
||||
logger.debug(f"Failed to send GO_AWAY: {e}")
|
||||
self.event_shutting_down.set()
|
||||
for stream in self.streams.values():
|
||||
stream.closed = True
|
||||
@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
|
||||
self.stream_events.clear()
|
||||
try:
|
||||
await self.secured_conn.close()
|
||||
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
|
||||
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
|
||||
self.event_closed.set()
|
||||
if self.on_close:
|
||||
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
|
||||
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
|
||||
if inspect.iscoroutinefunction(self.on_close):
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
|
||||
)
|
||||
logging.debug(f"Sending SYN header for stream {stream_id}")
|
||||
logger.debug(f"Sending SYN header for stream {stream_id}")
|
||||
await self.secured_conn.write(header)
|
||||
return stream
|
||||
except Exception as e:
|
||||
@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
|
||||
raise e
|
||||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
logging.debug("Waiting for new stream")
|
||||
logger.debug("Waiting for new stream")
|
||||
try:
|
||||
stream = await self.new_stream_receive_channel.receive()
|
||||
logging.debug(f"Received stream {stream.stream_id}")
|
||||
logger.debug(f"Received stream {stream.stream_id}")
|
||||
return stream
|
||||
except trio.EndOfChannel:
|
||||
raise MuxedStreamError("No new streams available")
|
||||
|
||||
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
|
||||
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
|
||||
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
while True:
|
||||
async with self.streams_lock:
|
||||
if stream_id not in self.streams:
|
||||
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
|
||||
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
|
||||
raise MuxedStreamEOF("Stream closed")
|
||||
if self.event_shutting_down.is_set():
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
|
||||
)
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
stream = self.streams[stream_id]
|
||||
buffer = self.stream_buffers.get(stream_id)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}: "
|
||||
f"closed={stream.closed}, "
|
||||
f"recv_closed={stream.recv_closed}, "
|
||||
@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
|
||||
f"buffer_len={len(buffer) if buffer else 0}"
|
||||
)
|
||||
if buffer is None:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"Buffer gone, assuming closed"
|
||||
)
|
||||
@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
|
||||
else:
|
||||
data = bytes(buffer[:n])
|
||||
del buffer[:n]
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Returning {len(data)} bytes"
|
||||
f"from stream {self.peer_id}:{stream_id}, "
|
||||
f"buffer_len={len(buffer)}"
|
||||
@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
|
||||
return data
|
||||
# If reset received and buffer is empty, raise reset
|
||||
if stream.reset_received:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"reset_received=True, raising MuxedStreamReset"
|
||||
)
|
||||
@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
|
||||
else:
|
||||
data = bytes(buffer[:n])
|
||||
del buffer[:n]
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Returning {len(data)} bytes"
|
||||
f"from stream {self.peer_id}:{stream_id}, "
|
||||
f"buffer_len={len(buffer)}"
|
||||
@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
|
||||
return data
|
||||
# Check if stream is closed
|
||||
if stream.closed:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"closed=True, raising MuxedStreamReset"
|
||||
)
|
||||
raise MuxedStreamReset("Stream is reset or closed")
|
||||
# Check if recv_closed and buffer empty
|
||||
if stream.recv_closed:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"recv_closed=True, buffer empty, raising EOF"
|
||||
)
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
|
||||
# Wait for data if stream is still open
|
||||
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
|
||||
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
|
||||
try:
|
||||
await self.stream_events[stream_id].wait()
|
||||
self.stream_events[stream_id] = trio.Event()
|
||||
@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
|
||||
try:
|
||||
header = await self.secured_conn.read(HEADER_SIZE)
|
||||
if not header or len(header) < HEADER_SIZE:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Connection closed orincomplete header for peer {self.peer_id}"
|
||||
)
|
||||
self.event_shutting_down.set()
|
||||
@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
|
||||
version, typ, flags, stream_id, length = struct.unpack(
|
||||
YAMUX_HEADER_FORMAT, header
|
||||
)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received header for peer {self.peer_id}:"
|
||||
f"type={typ}, flags={flags}, stream_id={stream_id},"
|
||||
f"length={length}"
|
||||
@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
|
||||
0,
|
||||
)
|
||||
await self.secured_conn.write(ack_header)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Sending stream {stream_id}"
|
||||
f"to channel for peer {self.peer_id}"
|
||||
)
|
||||
@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
|
||||
elif typ == TYPE_DATA and flags & FLAG_RST:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Resetting stream {stream_id} for peer {self.peer_id}"
|
||||
)
|
||||
self.streams[stream_id].closed = True
|
||||
@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
|
||||
elif typ == TYPE_DATA and flags & FLAG_ACK:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ACK for stream"
|
||||
f"{stream_id} for peer {self.peer_id}"
|
||||
)
|
||||
elif typ == TYPE_GO_AWAY:
|
||||
error_code = length
|
||||
if error_code == GO_AWAY_NORMAL:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received GO_AWAY for peer"
|
||||
f"{self.peer_id}: Normal termination"
|
||||
)
|
||||
elif error_code == GO_AWAY_PROTOCOL_ERROR:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
|
||||
)
|
||||
elif error_code == GO_AWAY_INTERNAL_ERROR:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
|
||||
)
|
||||
else:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer {self.peer_id}"
|
||||
f"with unknown error code: {error_code}"
|
||||
)
|
||||
@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
|
||||
break
|
||||
elif typ == TYPE_PING:
|
||||
if flags & FLAG_SYN:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ping request with value"
|
||||
f"{length} for peer {self.peer_id}"
|
||||
)
|
||||
@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
|
||||
)
|
||||
await self.secured_conn.write(ping_header)
|
||||
elif flags & FLAG_ACK:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ping response with value"
|
||||
f"{length} for peer {self.peer_id}"
|
||||
)
|
||||
@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
|
||||
self.stream_buffers[stream_id].extend(data)
|
||||
self.stream_events[stream_id].set()
|
||||
if flags & FLAG_FIN:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received FIN for stream {self.peer_id}:"
|
||||
f"{stream_id}, marking recv_closed"
|
||||
)
|
||||
@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
|
||||
if self.streams[stream_id].send_closed:
|
||||
self.streams[stream_id].closed = True
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading data for stream {stream_id}: {e}")
|
||||
logger.error(f"Error reading data for stream {stream_id}: {e}")
|
||||
# Mark stream as closed on read error
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
async with stream.window_lock:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received window update for stream"
|
||||
f"{self.peer_id}:{stream_id},"
|
||||
f" increment: {increment}"
|
||||
@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
|
||||
and details.get("requested_count") == 2
|
||||
and details.get("received_count") == 0
|
||||
):
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Stream closed cleanly for peer {self.peer_id}"
|
||||
+ f" (IncompleteReadError: {details})"
|
||||
)
|
||||
@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
|
||||
await self._cleanup_on_error()
|
||||
break
|
||||
else:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logging.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
# Handle RawConnError with more nuance
|
||||
if isinstance(e, RawConnError):
|
||||
error_msg = str(e)
|
||||
# If RawConnError is empty, it's likely normal cleanup
|
||||
if not error_msg.strip():
|
||||
logger.info(
|
||||
f"RawConnError (empty) during cleanup for peer "
|
||||
f"{self.peer_id} (normal connection shutdown)"
|
||||
)
|
||||
else:
|
||||
# Log non-empty RawConnError as warning
|
||||
logger.warning(
|
||||
f"RawConnError during connection handling for peer "
|
||||
f"{self.peer_id}: {error_msg}"
|
||||
)
|
||||
else:
|
||||
# Log all other errors normally
|
||||
logger.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
# Don't crash the whole connection for temporary errors
|
||||
if self.event_shutting_down.is_set() or isinstance(
|
||||
e, (RawConnError, OSError)
|
||||
@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
|
||||
# Close the secured connection
|
||||
try:
|
||||
await self.secured_conn.close()
|
||||
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
except Exception as close_error:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
|
||||
)
|
||||
|
||||
@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
|
||||
|
||||
# Call on_close callback if provided
|
||||
if self.on_close:
|
||||
logging.debug(f"Calling on_close for peer {self.peer_id}")
|
||||
logger.debug(f"Calling on_close for peer {self.peer_id}")
|
||||
try:
|
||||
if inspect.iscoroutinefunction(self.on_close):
|
||||
await self.on_close()
|
||||
else:
|
||||
self.on_close()
|
||||
except Exception as callback_error:
|
||||
logging.error(f"Error in on_close callback: {callback_error}")
|
||||
logger.error(f"Error in on_close callback: {callback_error}")
|
||||
|
||||
# Cancel nursery tasks
|
||||
if self._nursery:
|
||||
|
||||
@ -9,6 +9,7 @@ from libp2p.utils.varint import (
|
||||
read_varint_prefixed_bytes,
|
||||
decode_varint_from_bytes,
|
||||
decode_varint_with_size,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
@ -24,4 +25,5 @@ __all__ = [
|
||||
"read_varint_prefixed_bytes",
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import BinaryIO
|
||||
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.exceptions import (
|
||||
ParseError,
|
||||
)
|
||||
@ -25,42 +27,41 @@ HIGH_MASK = 2**7
|
||||
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
|
||||
|
||||
|
||||
def encode_uvarint(number: int) -> bytes:
|
||||
"""Pack `number` into varint bytes."""
|
||||
buf = b""
|
||||
while True:
|
||||
towrite = number & 0x7F
|
||||
number >>= 7
|
||||
if number:
|
||||
buf += bytes((towrite | 0x80,))
|
||||
else:
|
||||
buf += bytes((towrite,))
|
||||
def encode_uvarint(value: int) -> bytes:
|
||||
"""Encode an unsigned integer as a varint."""
|
||||
if value < 0:
|
||||
raise ValueError("Cannot encode negative value as uvarint")
|
||||
|
||||
result = bytearray()
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def decode_uvarint(data: bytes) -> int:
|
||||
"""Decode a varint from bytes."""
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
result = 0
|
||||
shift = 0
|
||||
|
||||
for byte in data:
|
||||
result |= (byte & 0x7F) << shift
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
return buf
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise ValueError("Varint too long")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def decode_varint_from_bytes(data: bytes) -> int:
|
||||
"""
|
||||
Decode a varint from bytes and return the value.
|
||||
|
||||
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
|
||||
"""
|
||||
res = 0
|
||||
for shift in itertools.count(0, 7):
|
||||
if shift > SHIFT_64_BIT_MAX:
|
||||
raise ParseError("Integer is too large...")
|
||||
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
value = data[0]
|
||||
data = data[1:]
|
||||
|
||||
res += (value & LOW_MASK) << shift
|
||||
|
||||
if not value & HIGH_MASK:
|
||||
break
|
||||
return res
|
||||
"""Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
|
||||
return decode_uvarint(data)
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
|
||||
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
|
||||
"""
|
||||
Decode a varint from bytes and return (value, bytes_consumed).
|
||||
Returns (0, 0) if the data doesn't start with a valid varint.
|
||||
Decode a varint from bytes and return both the value and the number of bytes
|
||||
consumed.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (value, bytes_consumed)
|
||||
|
||||
"""
|
||||
try:
|
||||
# Calculate how many bytes the varint consumes
|
||||
varint_size = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_size += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
result = 0
|
||||
shift = 0
|
||||
bytes_consumed = 0
|
||||
|
||||
if varint_size == 0:
|
||||
return 0, 0
|
||||
for byte in data:
|
||||
result |= (byte & 0x7F) << shift
|
||||
bytes_consumed += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise ValueError("Varint too long")
|
||||
|
||||
# Extract just the varint bytes
|
||||
varint_bytes = data[:varint_size]
|
||||
|
||||
# Decode the varint
|
||||
value = decode_varint_from_bytes(varint_bytes)
|
||||
|
||||
return value, varint_size
|
||||
except Exception:
|
||||
return 0, 0
|
||||
return result, bytes_consumed
|
||||
|
||||
|
||||
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||
varint_len = encode_uvarint(len(msg_bytes))
|
||||
return varint_len + msg_bytes
|
||||
def encode_varint_prefixed(data: bytes) -> bytes:
|
||||
"""Encode data with a varint length prefix."""
|
||||
length_bytes = encode_uvarint(len(data))
|
||||
return length_bytes + data
|
||||
|
||||
|
||||
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
|
||||
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
|
||||
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
|
||||
)
|
||||
return msg_bytes[:-1]
|
||||
|
||||
|
||||
def read_varint_prefixed_bytes_sync(
|
||||
stream: BinaryIO, max_length: int = 1024 * 1024
|
||||
) -> bytes:
|
||||
"""
|
||||
Read varint-prefixed bytes from a stream.
|
||||
|
||||
Args:
|
||||
stream: A stream-like object with a read() method
|
||||
max_length: Maximum allowed data length to prevent memory exhaustion
|
||||
|
||||
Returns:
|
||||
bytes: The data without the length prefix
|
||||
|
||||
Raises:
|
||||
ValueError: If the length prefix is invalid or too large
|
||||
EOFError: If the stream ends unexpectedly
|
||||
|
||||
"""
|
||||
# Read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
byte_data = stream.read(1)
|
||||
if not byte_data:
|
||||
raise EOFError("Stream ended while reading varint length prefix")
|
||||
|
||||
length_bytes += byte_data
|
||||
if byte_data[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
# Decode the length
|
||||
length = decode_uvarint(length_bytes)
|
||||
|
||||
if length > max_length:
|
||||
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
|
||||
|
||||
# Read the data
|
||||
data = stream.read(length)
|
||||
if len(data) != length:
|
||||
raise EOFError(f"Expected {length} bytes, got {len(data)}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def read_length_prefixed_protobuf(
|
||||
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
|
||||
) -> bytes:
|
||||
"""Read a protobuf message from a stream, handling both formats."""
|
||||
if use_varint_format:
|
||||
# Read length-prefixed protobuf message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
raise Exception("No length prefix received")
|
||||
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
if msg_length > max_length:
|
||||
raise Exception(
|
||||
f"Message length {msg_length} exceeds maximum allowed {max_length}"
|
||||
)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
raise Exception(
|
||||
f"Incomplete message: expected {msg_length}, got {len(data)}"
|
||||
)
|
||||
|
||||
return data
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, read all available data in one go
|
||||
data = await stream.read()
|
||||
|
||||
# If we got no data, raise an exception
|
||||
if not data:
|
||||
raise Exception("No data received in raw format")
|
||||
|
||||
if len(data) > max_length:
|
||||
raise Exception(
|
||||
f"Message length {len(data)} exceeds maximum allowed {max_length}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
1
newsfragments/784.bugfix.rst
Normal file
1
newsfragments/784.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Fixed incorrect handling of raw protobuf format in identify push protocol. The identify push example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.
|
||||
1
newsfragments/784.internal.rst
Normal file
1
newsfragments/784.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Yamux RawConnError Logging Refactor - Improved error handling and debug logging
|
||||
@ -0,0 +1,552 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.identity.identify_push.identify_push import (
|
||||
ID_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
push_identify_to_peers,
|
||||
)
|
||||
from tests.utils.factories import host_pair_factory
|
||||
|
||||
logger = logging.getLogger("libp2p.identity.identify-push-integration-test")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_protocol_varint_format_integration(security_protocol):
|
||||
"""Test identify/push protocol with varint format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to host_b so it has something to push
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
# The protocols should include the dummy protocols we added
|
||||
assert len(protocols) >= 2 # Should include the dummy protocols
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_protocol_raw_format_integration(security_protocol):
|
||||
"""Test identify/push protocol with raw format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 1 # Should include the dummy protocol
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_default_format_behavior(security_protocol):
|
||||
"""Test identify/push protocol uses correct default format."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Use default identify/push handler (should use varint format)
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id())
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 1 # Should include the dummy protocol
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_cross_format_compatibility_varint_to_raw(
|
||||
security_protocol,
|
||||
):
|
||||
"""Test varint pusher with raw listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Use an event to signal when handler is ready
|
||||
handler_ready = trio.Event()
|
||||
|
||||
# Create a wrapper handler that signals when ready
|
||||
original_handler = identify_push_handler_for(host_a, use_varint_format=False)
|
||||
|
||||
async def wrapped_handler(stream):
|
||||
handler_ready.set() # Signal that handler is ready
|
||||
await original_handler(stream)
|
||||
|
||||
# Host A uses raw format with wrapped handler
|
||||
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
|
||||
|
||||
# Host B pushes with varint format (should fail gracefully)
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=True
|
||||
)
|
||||
# This should fail due to format mismatch
|
||||
# Note: The format detection might be more robust than expected
|
||||
# so we just check that the operation completes
|
||||
assert isinstance(success, bool)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_cross_format_compatibility_raw_to_varint(
|
||||
security_protocol,
|
||||
):
|
||||
"""Test raw pusher with varint listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Use an event to signal when handler is ready
|
||||
handler_ready = trio.Event()
|
||||
|
||||
# Create a wrapper handler that signals when ready
|
||||
original_handler = identify_push_handler_for(host_a, use_varint_format=True)
|
||||
|
||||
async def wrapped_handler(stream):
|
||||
handler_ready.set() # Signal that handler is ready
|
||||
await original_handler(stream)
|
||||
|
||||
# Host A uses varint format with wrapped handler
|
||||
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
|
||||
|
||||
# Host B pushes with raw format (should fail gracefully)
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=False
|
||||
)
|
||||
# This should fail due to format mismatch
|
||||
# Note: The format detection might be more robust than expected
|
||||
# so we just check that the operation completes
|
||||
assert isinstance(success, bool)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_multiple_peers_integration(security_protocol):
|
||||
"""Test identify/push protocol with multiple peers."""
|
||||
# Create two hosts using the factory
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create a third host following the pattern from test_identify_push.py
|
||||
import multiaddr
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
# Create a new key pair for host_c
|
||||
key_pair_c = create_new_key_pair()
|
||||
host_c = new_host(key_pair=key_pair_c)
|
||||
|
||||
# Set up identify/push handlers on all hosts
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
|
||||
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
|
||||
|
||||
# Start listening on a random port using the run context manager
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
async with host_c.run([listen_addr]):
|
||||
# Connect host_c to host_a and host_b using the correct pattern
|
||||
await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0]))
|
||||
await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0]))
|
||||
|
||||
# Push identify information from host_a to all connected peers
|
||||
await push_identify_to_peers(host_a)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that host_b's peerstore has been updated with host_a's information
|
||||
peerstore_b = host_b.get_peerstore()
|
||||
peer_id_a = host_a.get_id()
|
||||
|
||||
# Check that the peer is in the peerstore
|
||||
assert peer_id_a in peerstore_b.peer_ids()
|
||||
|
||||
# Check that host_c's peerstore has been updated with host_a's information
|
||||
peerstore_c = host_c.get_peerstore()
|
||||
|
||||
# Check that the peer is in the peerstore
|
||||
assert peer_id_a in peerstore_c.peer_ids()
|
||||
|
||||
# Test for push_identify to only connected peers and not all peers
|
||||
# Disconnect a from c.
|
||||
await host_c.disconnect(host_a.get_id())
|
||||
|
||||
await push_identify_to_peers(host_c)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that host_a's peerstore has not been updated with host_c's info
|
||||
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
|
||||
# Check that host_b's peerstore has been updated with host_c's info
|
||||
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_large_message_handling(security_protocol):
|
||||
"""Test identify/push protocol handles large messages with many protocols."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add many protocols to make the message larger
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
for i in range(10):
|
||||
host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
|
||||
|
||||
# Also add some protocols to host_a to ensure it has protocols to push
|
||||
for i in range(5):
|
||||
host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=True
|
||||
)
|
||||
assert success
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated with all protocols
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 10 # Should include the dummy protocols
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_peerstore_update_completeness(security_protocol):
|
||||
"""Test that identify/push updates all relevant peerstore information."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id())
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that public key was added
|
||||
pubkey = peerstore_a.pubkey(peer_id_b)
|
||||
assert pubkey is not None
|
||||
assert pubkey.serialize() == host_b.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_concurrent_requests(security_protocol):
|
||||
"""Test identify/push protocol handles concurrent requests properly."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Make multiple concurrent push requests
|
||||
results = []
|
||||
|
||||
async def push_identify():
|
||||
result = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
results.append(result)
|
||||
|
||||
# Run multiple concurrent pushes using nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(push_identify)
|
||||
|
||||
# All should succeed
|
||||
assert len(results) == 3
|
||||
assert all(results)
|
||||
|
||||
# Wait a bit for the pushes to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_stream_handling(security_protocol):
|
||||
"""Test identify/push protocol properly handles stream lifecycle."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
success = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
assert success
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_error_handling(security_protocol):
|
||||
"""Test identify/push protocol handles errors gracefully."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create a handler that raises an exception but catches it to prevent test
|
||||
# failure
|
||||
async def error_handler(stream):
|
||||
try:
|
||||
await stream.close()
|
||||
raise Exception("Test error")
|
||||
except Exception:
|
||||
# Catch the exception to prevent it from propagating up
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(ID_PUSH, error_handler)
|
||||
|
||||
# Push should complete (message sent) but handler should fail gracefully
|
||||
success = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
assert success # The push operation itself succeeds (message sent)
|
||||
|
||||
# Wait a bit for the handler to process
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that the error was handled gracefully (no test failure)
|
||||
# The handler caught the exception and didn't propagate it
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_message_equivalence_real_network(security_protocol):
|
||||
"""Test that both formats produce equivalent peerstore updates in real network."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Test varint format
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Get peerstore state after varint push
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols_varint = peerstore_a.get_protocols(peer_id_b)
|
||||
addrs_varint = peerstore_a.addrs(peer_id_b)
|
||||
|
||||
# Clear peerstore for next test
|
||||
peerstore_a.clear_addrs(peer_id_b)
|
||||
peerstore_a.clear_protocol_data(peer_id_b)
|
||||
|
||||
# Test raw format
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Get peerstore state after raw push
|
||||
protocols_raw = peerstore_a.get_protocols(peer_id_b)
|
||||
addrs_raw = peerstore_a.addrs(peer_id_b)
|
||||
|
||||
# Both should produce equivalent peerstore updates
|
||||
# Check that both formats successfully updated protocols
|
||||
assert protocols_varint is not None
|
||||
assert protocols_raw is not None
|
||||
assert len(protocols_varint) > 0
|
||||
assert len(protocols_raw) > 0
|
||||
|
||||
# Check that both formats successfully updated addresses
|
||||
assert addrs_varint is not None
|
||||
assert addrs_raw is not None
|
||||
assert len(addrs_varint) > 0
|
||||
assert len(addrs_raw) > 0
|
||||
|
||||
# Both should contain the same essential information
|
||||
# (exact address lists might differ due to format-specific handling)
|
||||
assert set(protocols_varint) == set(protocols_raw)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_with_observed_address(security_protocol):
|
||||
"""Test identify/push protocol includes observed address information."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Get host_b's address as observed by host_a
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
host_b_addr = host_b.get_addrs()[0]
|
||||
observed_multiaddr = Multiaddr(str(host_b_addr))
|
||||
|
||||
# Push identify information with observed address
|
||||
await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr
|
||||
)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# The observed address should be among the stored addresses
|
||||
addr_strings = [str(addr) for addr in addrs]
|
||||
assert str(observed_multiaddr) in addr_strings
|
||||
Reference in New Issue
Block a user