mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into todo/handletimeout
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import base64
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -72,14 +73,52 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
format_flag = "--raw-format" if not use_varint_format else ""
|
||||
print(
|
||||
f"First host listening (using {format_name} format). "
|
||||
f"Run this from another console:\n\n"
|
||||
f"identify-demo "
|
||||
f"-d {client_addr}\n"
|
||||
f"identify-demo {format_flag} -d {client_addr}\n"
|
||||
)
|
||||
print("Waiting for incoming identify request...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
# Add a custom handler to show connection events
|
||||
async def custom_identify_handler(stream):
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n🔗 Received identify request from peer: {peer_id}")
|
||||
|
||||
# Show remote address in multiaddr format
|
||||
try:
|
||||
from libp2p.identity.identify.identify import (
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(
|
||||
remote_address
|
||||
)
|
||||
# Add the peer ID to create a complete multiaddr
|
||||
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||
print(f" Remote address: {complete_multiaddr}")
|
||||
else:
|
||||
print(f" Remote address: {remote_address}")
|
||||
except Exception:
|
||||
print(f" Remote address: {stream.get_remote_address()}")
|
||||
|
||||
# Call the original handler
|
||||
await identify_handler(stream)
|
||||
|
||||
print(f"✅ Successfully processed identify request from {peer_id}")
|
||||
|
||||
# Replace the handler with our custom one
|
||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
|
||||
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer)
|
||||
@ -93,25 +132,74 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print(f"Second host connecting to peer: {info.peer_id}")
|
||||
|
||||
await host_b.connect(info)
|
||||
try:
|
||||
await host_b.connect(info)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
||||
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
print(
|
||||
"\n💡 Make sure the peer is running and the address is correct."
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Re-raise other exceptions
|
||||
raise
|
||||
|
||||
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
|
||||
|
||||
try:
|
||||
print("Starting identify protocol...")
|
||||
|
||||
# Read the complete response (could be either format)
|
||||
# Read a larger chunk to get all the data before stream closes
|
||||
response = await stream.read(8192) # Read enough data in one go
|
||||
# Read the response using the utility function
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
response = await read_length_prefixed_protobuf(
|
||||
stream, use_varint_format
|
||||
)
|
||||
full_response = response
|
||||
|
||||
await stream.close()
|
||||
|
||||
# Parse the response using the robust protocol-level function
|
||||
# This handles both old and new formats automatically
|
||||
identify_msg = parse_identify_response(response)
|
||||
identify_msg = parse_identify_response(full_response)
|
||||
print_identify_response(identify_msg)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Identify protocol error: {e}")
|
||||
error_msg = str(e)
|
||||
print(f"Identify protocol error: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"listener"
|
||||
)
|
||||
print("is using raw protobuf format.")
|
||||
print(
|
||||
"\nTo fix this, run the dialer with the --raw-format flag:"
|
||||
)
|
||||
print(f"identify-demo --raw-format -d {destination}")
|
||||
else:
|
||||
print("You are using raw protobuf format but the listener")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format "
|
||||
"flag:"
|
||||
)
|
||||
print(f"identify-demo -d {destination}")
|
||||
print("=" * 60)
|
||||
else:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
return
|
||||
|
||||
@ -147,6 +235,7 @@ def main() -> None:
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
@ -154,9 +243,19 @@ def main() -> None:
|
||||
use_varint_format = not args.raw_format
|
||||
|
||||
try:
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
if args.destination:
|
||||
# Run in dialer mode
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
else:
|
||||
# Run in listener mode
|
||||
trio.run(run, *(args.port, args.destination, use_varint_format))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
print("\n👋 Goodbye!")
|
||||
logger.info("Application interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {str(e)}")
|
||||
logger.error("Error: %s", str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -11,23 +11,26 @@ This example shows how to:
|
||||
|
||||
import logging
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.identity.identify import (
|
||||
identify_handler_for,
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.identity.identify_push import (
|
||||
ID_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_custom_identify_handler(host, host_name: str):
|
||||
"""Create a custom identify handler that displays received information."""
|
||||
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
|
||||
|
||||
# Get the standard identify response using the existing function
|
||||
from libp2p.identity.identify.identify import (
|
||||
_mk_identify_protobuf,
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
|
||||
# Get observed address
|
||||
observed_multiaddr = None
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build the identify protobuf
|
||||
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
|
||||
response_data = identify_msg.SerializeToString()
|
||||
|
||||
print(f" 📋 {host_name} identify information:")
|
||||
if identify_msg.HasField("protocol_version"):
|
||||
print(f" Protocol Version: {identify_msg.protocol_version}")
|
||||
if identify_msg.HasField("agent_version"):
|
||||
print(f" Agent Version: {identify_msg.agent_version}")
|
||||
if identify_msg.HasField("public_key"):
|
||||
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
|
||||
if identify_msg.listen_addrs:
|
||||
print(" Listen Addresses:")
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
addr = multiaddr.Multiaddr(addr_bytes)
|
||||
print(f" - {addr}")
|
||||
if identify_msg.protocols:
|
||||
print(" Supported Protocols:")
|
||||
for protocol in identify_msg.protocols:
|
||||
print(f" - {protocol}")
|
||||
|
||||
# Send the response
|
||||
await stream.write(response_data)
|
||||
await stream.close()
|
||||
|
||||
return handle_identify
|
||||
|
||||
|
||||
def create_custom_identify_push_handler(host, host_name: str):
|
||||
"""Create a custom identify/push handler that displays received information."""
|
||||
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")
|
||||
|
||||
try:
|
||||
# Read the identify message using the utility function
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
|
||||
|
||||
# Parse the identify message
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
|
||||
print(" 📋 Received identify information:")
|
||||
if identify_msg.HasField("protocol_version"):
|
||||
print(f" Protocol Version: {identify_msg.protocol_version}")
|
||||
if identify_msg.HasField("agent_version"):
|
||||
print(f" Agent Version: {identify_msg.agent_version}")
|
||||
if identify_msg.HasField("public_key"):
|
||||
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
|
||||
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
|
||||
observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
|
||||
print(f" Observed Address: {observed_addr}")
|
||||
if identify_msg.listen_addrs:
|
||||
print(" Listen Addresses:")
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
addr = multiaddr.Multiaddr(addr_bytes)
|
||||
print(f" - {addr}")
|
||||
if identify_msg.protocols:
|
||||
print(" Supported Protocols:")
|
||||
for protocol in identify_msg.protocols:
|
||||
print(f" - {protocol}")
|
||||
|
||||
# Update the peerstore with the new information
|
||||
from libp2p.identity.identify_push.identify_push import (
|
||||
_update_peerstore_from_identify,
|
||||
)
|
||||
|
||||
await _update_peerstore_from_identify(
|
||||
host.get_peerstore(), peer_id, identify_msg
|
||||
)
|
||||
|
||||
print(f" ✅ {host_name} updated peerstore with new information")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error processing identify/push: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def display_peerstore_info(host, host_name: str, peer_id, description: str):
|
||||
"""Display peerstore information for a specific peer."""
|
||||
peerstore = host.get_peerstore()
|
||||
|
||||
try:
|
||||
addrs = peerstore.addrs(peer_id)
|
||||
except Exception:
|
||||
addrs = []
|
||||
|
||||
try:
|
||||
protocols = peerstore.get_protocols(peer_id)
|
||||
except Exception:
|
||||
protocols = []
|
||||
|
||||
print(f"\n📚 {host_name} peerstore for {description}:")
|
||||
print(f" Peer ID: {peer_id}")
|
||||
if addrs:
|
||||
print(" Addresses:")
|
||||
for addr in addrs:
|
||||
print(f" - {addr}")
|
||||
else:
|
||||
print(" Addresses: None")
|
||||
|
||||
if protocols:
|
||||
print(" Protocols:")
|
||||
for protocol in protocols:
|
||||
print(f" - {protocol}")
|
||||
else:
|
||||
print(" Protocols: None")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("\n==== Starting Identify-Push Example ====\n")
|
||||
print("\n==== Starting Enhanced Identify-Push Example ====\n")
|
||||
|
||||
# Create key pairs for the two hosts
|
||||
key_pair_1 = create_new_key_pair()
|
||||
@ -48,45 +188,49 @@ async def main() -> None:
|
||||
# Create the first host
|
||||
host_1 = new_host(key_pair=key_pair_1)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
|
||||
host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_1.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
|
||||
)
|
||||
host_1.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
|
||||
)
|
||||
|
||||
# Create the second host
|
||||
host_2 = new_host(key_pair=key_pair_2)
|
||||
|
||||
# Set up the identify and identify/push handlers
|
||||
host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
|
||||
host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
|
||||
# Set up custom identify and identify/push handlers
|
||||
host_2.set_stream_handler(
|
||||
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
|
||||
)
|
||||
host_2.set_stream_handler(
|
||||
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
|
||||
)
|
||||
|
||||
# Start listening on random ports using the run context manager
|
||||
import multiaddr
|
||||
|
||||
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
|
||||
async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
|
||||
# Get the addresses of both hosts
|
||||
addr_1 = host_1.get_addrs()[0]
|
||||
logger.info(f"Host 1 listening on {addr_1}")
|
||||
print(f"Host 1 listening on {addr_1}")
|
||||
print(f"Peer ID: {host_1.get_id().pretty()}")
|
||||
|
||||
addr_2 = host_2.get_addrs()[0]
|
||||
logger.info(f"Host 2 listening on {addr_2}")
|
||||
print(f"Host 2 listening on {addr_2}")
|
||||
print(f"Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\nConnecting Host 2 to Host 1...")
|
||||
print("🏠 Host Configuration:")
|
||||
print(f" Host 1: {addr_1}")
|
||||
print(f" Host 1 Peer ID: {host_1.get_id().pretty()}")
|
||||
print(f" Host 2: {addr_2}")
|
||||
print(f" Host 2 Peer ID: {host_2.get_id().pretty()}")
|
||||
|
||||
print("\n🔗 Connecting Host 2 to Host 1...")
|
||||
|
||||
# Connect host_2 to host_1
|
||||
peer_info = info_from_p2p_addr(addr_1)
|
||||
await host_2.connect(peer_info)
|
||||
logger.info("Host 2 connected to Host 1")
|
||||
print("Host 2 successfully connected to Host 1")
|
||||
print("✅ Host 2 successfully connected to Host 1")
|
||||
|
||||
# Run the identify protocol from host_2 to host_1
|
||||
# (so Host 1 learns Host 2's address)
|
||||
print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
|
||||
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
|
||||
|
||||
stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
@ -94,64 +238,58 @@ async def main() -> None:
|
||||
await stream.close()
|
||||
|
||||
# Run the identify protocol from host_1 to host_2
|
||||
# (so Host 2 learns Host 1's address)
|
||||
print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
|
||||
stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
||||
# --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
|
||||
# Update Host 1's peerstore with Host 2's addresses
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(response)
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
for addr_bytes in identify_msg.listen_addrs:
|
||||
maddr = multiaddr.Multiaddr(addr_bytes)
|
||||
# TTL can be any positive int
|
||||
peerstore_1.add_addr(
|
||||
peer_id_2,
|
||||
maddr,
|
||||
ttl=3600,
|
||||
)
|
||||
# --- END NEW CODE ---
|
||||
peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)
|
||||
|
||||
# Now Host 1's peerstore should have Host 2's address
|
||||
peerstore_1 = host_1.get_peerstore()
|
||||
peer_id_2 = host_2.get_id()
|
||||
addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
|
||||
logger.info(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
)
|
||||
print(
|
||||
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
|
||||
f"{addrs_1_for_2}"
|
||||
# Display peerstore information before push
|
||||
await display_peerstore_info(
|
||||
host_1, "Host 1", peer_id_2, "Host 2 (before push)"
|
||||
)
|
||||
|
||||
# Push identify information from host_1 to host_2
|
||||
logger.info("Host 1 pushing identify information to Host 2")
|
||||
print("\nHost 1 pushing identify information to Host 2...")
|
||||
print("\n📤 Host 1 pushing identify information to Host 2...")
|
||||
|
||||
try:
|
||||
# Call push_identify_to_peer which now returns a boolean
|
||||
success = await push_identify_to_peer(host_1, host_2.get_id())
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
else:
|
||||
logger.warning("Identify push didn't complete successfully")
|
||||
print("\nWarning: Identify push didn't complete successfully")
|
||||
print("⚠️ Identify push didn't complete successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
print(f"❌ Error during identify push: {str(e)}")
|
||||
|
||||
# Add this at the end of your async with block:
|
||||
await trio.sleep(0.5) # Give background tasks time to finish
|
||||
# Give a moment for the identify/push processing to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Display peerstore information after push
|
||||
await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
|
||||
await display_peerstore_info(
|
||||
host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
|
||||
)
|
||||
|
||||
# Give more time for background tasks to finish and connections to stabilize
|
||||
print("\n⏳ Waiting for background tasks to complete...")
|
||||
await trio.sleep(1.0)
|
||||
|
||||
# Gracefully close connections to prevent connection errors
|
||||
print("🔌 Closing connections...")
|
||||
await host_2.disconnect(host_1.get_id())
|
||||
await trio.sleep(0.2)
|
||||
|
||||
print("\n🎉 Example completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -41,6 +41,9 @@ from libp2p.identity.identify import (
|
||||
ID as ID_IDENTIFY,
|
||||
identify_handler_for,
|
||||
)
|
||||
from libp2p.identity.identify.identify import (
|
||||
_remote_address_to_multiaddr,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
async def handle_identify_push(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
# Get remote address information
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
from libp2p.utils.varint import decode_varint_from_bytes
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
|
||||
logger.info(
|
||||
"Connection from remote peer %s, address: %s, multiaddr: %s",
|
||||
peer_id,
|
||||
remote_address,
|
||||
observed_multiaddr,
|
||||
)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
# Add the peer ID to create a complete multiaddr
|
||||
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||
print(f" Remote address: {complete_multiaddr}")
|
||||
except Exception as e:
|
||||
logger.error("Error getting remote address: %s", e)
|
||||
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
|
||||
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
try:
|
||||
# Use the utility function to read the protobuf message
|
||||
from libp2p.utils.varint import read_length_prefixed_protobuf
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
data = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)
|
||||
|
||||
logger.info("Successfully processed identify/push from peer %s", peer_id)
|
||||
print(f"\nSuccessfully processed identify/push from peer {peer_id}")
|
||||
print(f"✅ Successfully processed identify/push from peer {peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
print(f"\nError processing identify/push from {peer_id}: {e}")
|
||||
error_msg = str(e)
|
||||
logger.error(
|
||||
"Error processing identify/push from %s: %s", peer_id, error_msg
|
||||
)
|
||||
print(f"\nError processing identify/push from {peer_id}: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"dialer is using raw protobuf format."
|
||||
)
|
||||
print("\nTo fix this, run the dialer with the --raw-format flag:")
|
||||
print(
|
||||
"identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the dialer")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format flag:"
|
||||
)
|
||||
print("identify-push-listener-dialer-demo -d <ADDRESS>")
|
||||
print("=" * 60)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
|
||||
return handle_identify_push
|
||||
|
||||
|
||||
async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
async def run_listener(
|
||||
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
|
||||
) -> None:
|
||||
"""Run a host in listener mode."""
|
||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||
print(
|
||||
@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None:
|
||||
)
|
||||
host.set_stream_handler(
|
||||
ID_IDENTIFY_PUSH,
|
||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||
)
|
||||
|
||||
# Start listening
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
try:
|
||||
async with host.run([listen_addr]):
|
||||
addr = host.get_addrs()[0]
|
||||
logger.info("Listener host ready!")
|
||||
print("Listener host ready!")
|
||||
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
logger.info(f"Listening on: {addr}")
|
||||
print(f"Listening on: {addr}")
|
||||
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||
print(f"Peer ID: {host.get_id().pretty()}")
|
||||
|
||||
print("\nRun dialer with command:")
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming connections... (Ctrl+C to exit)")
|
||||
print("\nRun dialer with command:")
|
||||
if raw_format_flag:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
|
||||
else:
|
||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
||||
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
|
||||
|
||||
# Keep running until interrupted
|
||||
await trio.sleep_forever()
|
||||
# Keep running until interrupted
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down listener...")
|
||||
logger.info("Listener interrupted by user")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Listener error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def run_dialer(
|
||||
@ -256,7 +293,9 @@ async def run_dialer(
|
||||
try:
|
||||
await host.connect(peer_info)
|
||||
logger.info("Successfully connected to listener!")
|
||||
print("Successfully connected to listener!")
|
||||
print("✅ Successfully connected to listener!")
|
||||
print(f" Connected to: {peer_info.peer_id}")
|
||||
print(f" Full address: {destination}")
|
||||
|
||||
# Push identify information to the listener
|
||||
logger.info("Pushing identify information to listener...")
|
||||
@ -270,7 +309,7 @@ async def run_dialer(
|
||||
|
||||
if success:
|
||||
logger.info("Identify push completed successfully!")
|
||||
print("Identify push completed successfully!")
|
||||
print("✅ Identify push completed successfully!")
|
||||
|
||||
logger.info("Example completed successfully!")
|
||||
print("\nExample completed successfully!")
|
||||
@ -281,17 +320,57 @@ async def run_dialer(
|
||||
logger.warning("Example completed with warnings.")
|
||||
print("Example completed with warnings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during identify push: {str(e)}")
|
||||
print(f"\nError during identify push: {str(e)}")
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error during identify push: {error_msg}")
|
||||
print(f"\nError during identify push: {error_msg}")
|
||||
|
||||
# Check for specific format mismatch errors
|
||||
if (
|
||||
"Error parsing message" in error_msg
|
||||
or "DecodeError" in error_msg
|
||||
or "ParseFromString" in error_msg
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print("FORMAT MISMATCH DETECTED!")
|
||||
print("=" * 60)
|
||||
if use_varint_format:
|
||||
print(
|
||||
"You are using length-prefixed format (default) but the "
|
||||
"listener is using raw protobuf format."
|
||||
)
|
||||
print(
|
||||
"\nTo fix this, run the dialer with the --raw-format flag:"
|
||||
)
|
||||
print(
|
||||
f"identify-push-listener-dialer-demo --raw-format -d "
|
||||
f"{destination}"
|
||||
)
|
||||
else:
|
||||
print("You are using raw protobuf format but the listener")
|
||||
print("is using length-prefixed format (default).")
|
||||
print(
|
||||
"\nTo fix this, run the dialer without the --raw-format "
|
||||
"flag:"
|
||||
)
|
||||
print(f"identify-push-listener-dialer-demo -d {destination}")
|
||||
print("=" * 60)
|
||||
|
||||
logger.error("Example completed with errors.")
|
||||
print("Example completed with errors.")
|
||||
# Continue execution despite the push error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during dialer operation: {str(e)}")
|
||||
print(f"\nError during dialer operation: {str(e)}")
|
||||
raise
|
||||
error_msg = str(e)
|
||||
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
||||
print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
print("\n💡 Make sure the peer is running and the address is correct.")
|
||||
return
|
||||
else:
|
||||
logger.error(f"Error during dialer operation: {error_msg}")
|
||||
print(f"\nError during dialer operation: {error_msg}")
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -301,12 +380,21 @@ def main() -> None:
|
||||
Without arguments, it runs as a listener on random port.
|
||||
With -d parameter, it runs as a dialer on random port.
|
||||
|
||||
Port 0 (default) means the OS will automatically assign an available port.
|
||||
This prevents port conflicts when running multiple instances.
|
||||
|
||||
Use --raw-format to send raw protobuf messages (old format) instead of
|
||||
length-prefixed protobuf messages (new format, default).
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
default=0,
|
||||
type=int,
|
||||
help="source port number (0 = random available port)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
@ -321,6 +409,7 @@ def main() -> None:
|
||||
"length-prefixed (new format)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine format: raw format if --raw-format is specified, otherwise
|
||||
@ -333,12 +422,12 @@ def main() -> None:
|
||||
trio.run(run_dialer, args.port, args.destination, use_varint_format)
|
||||
else:
|
||||
# Run in listener mode with random available port if not specified
|
||||
trio.run(run_listener, args.port, use_varint_format)
|
||||
trio.run(run_listener, args.port, use_varint_format, args.raw_format)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
logger.info("Interrupted by user")
|
||||
print("\n👋 Goodbye!")
|
||||
logger.info("Application interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\nError: {str(e)}")
|
||||
print(f"\n❌ Error: {str(e)}")
|
||||
logger.error("Error: %s", str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:
|
||||
|
||||
|
||||
def identify_handler_for(
|
||||
host: IHost, use_varint_format: bool = False
|
||||
host: IHost, use_varint_format: bool = True
|
||||
) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
# get observed address from ``stream``
|
||||
|
||||
@ -28,7 +28,7 @@ from libp2p.utils import (
|
||||
varint,
|
||||
)
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
|
||||
from ..identify.identify import (
|
||||
@ -66,49 +66,8 @@ def identify_push_handler_for(
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
try:
|
||||
if use_varint_format:
|
||||
# Read length-prefixed identify message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
logger.warning("No length prefix received from peer %s", peer_id)
|
||||
return
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
logger.warning("Incomplete message received from peer %s", peer_id)
|
||||
return
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, we need to read all data before the stream is closed
|
||||
data = b""
|
||||
try:
|
||||
# Read all available data in a single operation
|
||||
data = await stream.read()
|
||||
except StreamClosed:
|
||||
# Try to read any remaining data
|
||||
try:
|
||||
data = await stream.read()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we got no data, log a warning and return
|
||||
if not data:
|
||||
logger.warning(
|
||||
"No data received in raw format from peer %s", peer_id
|
||||
)
|
||||
return
|
||||
# Use the utility function to read the protobuf message
|
||||
data = await read_length_prefixed_protobuf(stream, use_varint_format)
|
||||
|
||||
identify_msg = Identify()
|
||||
identify_msg.ParseFromString(data)
|
||||
@ -119,6 +78,11 @@ def identify_push_handler_for(
|
||||
)
|
||||
|
||||
logger.debug("Successfully processed identify/push from peer %s", peer_id)
|
||||
|
||||
# Send acknowledgment to indicate successful processing
|
||||
# This ensures the sender knows the message was received before closing
|
||||
await stream.write(b"OK")
|
||||
|
||||
except StreamClosed:
|
||||
logger.debug(
|
||||
"Stream closed while processing identify/push from %s", peer_id
|
||||
@ -127,7 +91,10 @@ def identify_push_handler_for(
|
||||
logger.error("Error processing identify/push from %s: %s", peer_id, e)
|
||||
finally:
|
||||
# Close the stream after processing
|
||||
await stream.close()
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing
|
||||
|
||||
return handle_identify_push
|
||||
|
||||
@ -226,7 +193,20 @@ async def push_identify_to_peer(
|
||||
# Send raw protobuf message
|
||||
await stream.write(response)
|
||||
|
||||
# Close the stream
|
||||
# Wait for acknowledgment from the receiver with timeout
|
||||
# This ensures the message was processed before closing
|
||||
try:
|
||||
with trio.move_on_after(1.0): # 1 second timeout
|
||||
ack = await stream.read(2) # Read "OK" acknowledgment
|
||||
if ack != b"OK":
|
||||
logger.warning(
|
||||
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
|
||||
# Continue anyway, as the message might have been processed
|
||||
|
||||
# Close the stream after acknowledgment (or timeout)
|
||||
await stream.close()
|
||||
|
||||
logger.debug("Successfully pushed identify to peer %s", peer_id)
|
||||
|
||||
@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
|
||||
is_async: bool
|
||||
|
||||
|
||||
MAX_CONCURRENT_VALIDATORS = 10
|
||||
|
||||
|
||||
class Pubsub(Service, IPubsub):
|
||||
host: IHost
|
||||
|
||||
@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||
_validator_semaphore: trio.Semaphore
|
||||
|
||||
seen_messages: LastSeenCache
|
||||
|
||||
@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
|
||||
# Therefore, we can only close from the receive side.
|
||||
self.peer_receive_channel = peer_receive
|
||||
self.dead_peer_receive_channel = dead_peer_receive
|
||||
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
|
||||
# Register a notifee
|
||||
self.host.get_network().register_notifee(
|
||||
PubsubNotifee(peer_send, dead_peer_send)
|
||||
@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
logger.debug("successfully published message %s", msg)
|
||||
|
||||
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
async def validate_msg(
|
||||
self,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
) -> None:
|
||||
"""
|
||||
Validate the received message.
|
||||
|
||||
@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
|
||||
if not validator(msg_forwarder, msg):
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
# TODO: Implement throttle on async validators
|
||||
|
||||
if len(async_topic_validators) > 0:
|
||||
# Appends to lists are thread safe in CPython
|
||||
results = []
|
||||
|
||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
||||
result = await func(msg_forwarder, msg)
|
||||
results.append(result)
|
||||
results: list[bool] = []
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for async_validator in async_topic_validators:
|
||||
nursery.start_soon(run_async_validator, async_validator)
|
||||
nursery.start_soon(
|
||||
self._run_async_validator,
|
||||
async_validator,
|
||||
msg_forwarder,
|
||||
msg,
|
||||
results,
|
||||
)
|
||||
|
||||
if not all(results):
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
async def _run_async_validator(
|
||||
self,
|
||||
func: AsyncValidatorFn,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
results: list[bool],
|
||||
) -> None:
|
||||
async with self._validator_semaphore:
|
||||
result = await func(msg_forwarder, msg)
|
||||
results.append(result)
|
||||
|
||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
Push a pubsub message to others.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
@ -32,6 +34,72 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
"""
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
|
||||
read_deadline: int | None
|
||||
write_deadline: int | None
|
||||
|
||||
# TODO: Add lock for read/write to avoid interleaving receiving messages?
|
||||
rw_lock: ReadWriteLock
|
||||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
|
||||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.rw_lock = ReadWriteLock()
|
||||
self.incoming_data_channel = incoming_data_channel
|
||||
self._buf = bytearray()
|
||||
|
||||
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
|
||||
:param n: number of bytes to read
|
||||
:return: bytes actually read
|
||||
"""
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
"the number of bytes to read `n` must be non-negative or "
|
||||
f"`None` to indicate read until EOF, got n={n}"
|
||||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if n is None:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until there is
|
||||
# no data, then return.
|
||||
try:
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with
|
||||
# `receive` and catch all kinds of errors here.
|
||||
async with self.rw_lock.read_lock():
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
"the number of bytes to read `n` must be non-negative or "
|
||||
f"`None` to indicate read until EOF, got n={n}"
|
||||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if n is None:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until
|
||||
# there is no data, then return.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when we are
|
||||
# waiting for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
|
||||
"`incoming_data_channel` is closed but stream is not reset. "
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with
|
||||
# `receive` and catch all kinds of errors here.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when
|
||||
# we are waiting for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
|
||||
"`incoming_data_channel` is closed but stream is not reset."
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
|
||||
|
||||
:return: number of bytes written
|
||||
"""
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.is_initiator
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
await self.muxed_conn.send_message(flag, data, self.stream_id)
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.is_initiator
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
await self.muxed_conn.send_message(flag, data, self.stream_id)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
|
||||
@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamReset,
|
||||
)
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
|
||||
PROTOCOL_ID = "/yamux/1.0.0"
|
||||
TYPE_DATA = 0x0
|
||||
TYPE_WINDOW_UPDATE = 0x1
|
||||
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
|
||||
"""
|
||||
if increment <= 0:
|
||||
# If increment is zero or negative, skip sending update
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Skipping window update"
|
||||
f"(increment={increment})"
|
||||
)
|
||||
return
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Sending window update with increment={increment}"
|
||||
)
|
||||
|
||||
@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
# If the stream is closed for receiving and the buffer is empty, raise EOF
|
||||
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
|
||||
)
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
# If buffer is not available, check if stream is closed
|
||||
if buffer is None:
|
||||
logging.debug(f"Stream {self.stream_id}: No buffer available")
|
||||
logger.debug(f"Stream {self.stream_id}: No buffer available")
|
||||
raise MuxedStreamEOF("Stream buffer closed")
|
||||
|
||||
# If we have data in buffer, process it
|
||||
@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
|
||||
# Send window update for the chunk we just read
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(chunk)
|
||||
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
|
||||
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
|
||||
await self.send_window_update(len(chunk), skip_lock=True)
|
||||
|
||||
# If stream is closed (FIN received) and buffer is empty, break
|
||||
if self.recv_closed and len(buffer) == 0:
|
||||
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
|
||||
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
|
||||
break
|
||||
|
||||
# If stream was reset, raise reset error
|
||||
if self.reset_received:
|
||||
logging.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
logger.debug(f"Stream {self.stream_id}: Stream was reset")
|
||||
raise MuxedStreamReset("Stream was reset")
|
||||
|
||||
# Wait for more data or stream closure
|
||||
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
|
||||
await self.conn.stream_events[self.stream_id].wait()
|
||||
self.conn.stream_events[self.stream_id] = trio.Event()
|
||||
|
||||
# After loop exit, first check if we have data to return
|
||||
if data:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
|
||||
)
|
||||
return data
|
||||
|
||||
# No data accumulated, now check why we exited the loop
|
||||
if self.conn.event_shutting_down.is_set():
|
||||
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
|
||||
# Return empty data
|
||||
@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
|
||||
data = await self.conn.read_stream(self.stream_id, n)
|
||||
async with self.window_lock:
|
||||
self.recv_window += len(data)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Sending window update after read, "
|
||||
f"increment={len(data)}"
|
||||
)
|
||||
@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
|
||||
|
||||
async def reset(self) -> None:
|
||||
if not self.closed:
|
||||
logging.debug(f"Resetting stream {self.stream_id}")
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
|
||||
self._nursery: Nursery | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
logging.debug(f"Starting Yamux for {self.peer_id}")
|
||||
logger.debug(f"Starting Yamux for {self.peer_id}")
|
||||
if self.event_started.is_set():
|
||||
return
|
||||
async with trio.open_nursery() as nursery:
|
||||
@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
|
||||
return self.is_initiator_value
|
||||
|
||||
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
|
||||
logging.debug(f"Closing Yamux connection with code {error_code}")
|
||||
logger.debug(f"Closing Yamux connection with code {error_code}")
|
||||
async with self.streams_lock:
|
||||
if not self.event_shutting_down.is_set():
|
||||
try:
|
||||
@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
|
||||
)
|
||||
await self.secured_conn.write(header)
|
||||
except Exception as e:
|
||||
logging.debug(f"Failed to send GO_AWAY: {e}")
|
||||
logger.debug(f"Failed to send GO_AWAY: {e}")
|
||||
self.event_shutting_down.set()
|
||||
for stream in self.streams.values():
|
||||
stream.closed = True
|
||||
@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
|
||||
self.stream_events.clear()
|
||||
try:
|
||||
await self.secured_conn.close()
|
||||
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
|
||||
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
|
||||
self.event_closed.set()
|
||||
if self.on_close:
|
||||
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
|
||||
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
|
||||
if inspect.iscoroutinefunction(self.on_close):
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
|
||||
)
|
||||
logging.debug(f"Sending SYN header for stream {stream_id}")
|
||||
logger.debug(f"Sending SYN header for stream {stream_id}")
|
||||
await self.secured_conn.write(header)
|
||||
return stream
|
||||
except Exception as e:
|
||||
@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
|
||||
raise e
|
||||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
logging.debug("Waiting for new stream")
|
||||
logger.debug("Waiting for new stream")
|
||||
try:
|
||||
stream = await self.new_stream_receive_channel.receive()
|
||||
logging.debug(f"Received stream {stream.stream_id}")
|
||||
logger.debug(f"Received stream {stream.stream_id}")
|
||||
return stream
|
||||
except trio.EndOfChannel:
|
||||
raise MuxedStreamError("No new streams available")
|
||||
|
||||
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
|
||||
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
|
||||
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
while True:
|
||||
async with self.streams_lock:
|
||||
if stream_id not in self.streams:
|
||||
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
|
||||
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
|
||||
raise MuxedStreamEOF("Stream closed")
|
||||
if self.event_shutting_down.is_set():
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
|
||||
)
|
||||
raise MuxedStreamEOF("Connection shut down")
|
||||
stream = self.streams[stream_id]
|
||||
buffer = self.stream_buffers.get(stream_id)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}: "
|
||||
f"closed={stream.closed}, "
|
||||
f"recv_closed={stream.recv_closed}, "
|
||||
@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
|
||||
f"buffer_len={len(buffer) if buffer else 0}"
|
||||
)
|
||||
if buffer is None:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"Buffer gone, assuming closed"
|
||||
)
|
||||
@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
|
||||
else:
|
||||
data = bytes(buffer[:n])
|
||||
del buffer[:n]
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Returning {len(data)} bytes"
|
||||
f"from stream {self.peer_id}:{stream_id}, "
|
||||
f"buffer_len={len(buffer)}"
|
||||
@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
|
||||
return data
|
||||
# If reset received and buffer is empty, raise reset
|
||||
if stream.reset_received:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"reset_received=True, raising MuxedStreamReset"
|
||||
)
|
||||
@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
|
||||
else:
|
||||
data = bytes(buffer[:n])
|
||||
del buffer[:n]
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Returning {len(data)} bytes"
|
||||
f"from stream {self.peer_id}:{stream_id}, "
|
||||
f"buffer_len={len(buffer)}"
|
||||
@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
|
||||
return data
|
||||
# Check if stream is closed
|
||||
if stream.closed:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"closed=True, raising MuxedStreamReset"
|
||||
)
|
||||
raise MuxedStreamReset("Stream is reset or closed")
|
||||
# Check if recv_closed and buffer empty
|
||||
if stream.recv_closed:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Stream {self.peer_id}:{stream_id}:"
|
||||
f"recv_closed=True, buffer empty, raising EOF"
|
||||
)
|
||||
raise MuxedStreamEOF("Stream is closed for receiving")
|
||||
|
||||
# Wait for data if stream is still open
|
||||
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
|
||||
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
|
||||
try:
|
||||
await self.stream_events[stream_id].wait()
|
||||
self.stream_events[stream_id] = trio.Event()
|
||||
@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
|
||||
try:
|
||||
header = await self.secured_conn.read(HEADER_SIZE)
|
||||
if not header or len(header) < HEADER_SIZE:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Connection closed orincomplete header for peer {self.peer_id}"
|
||||
)
|
||||
self.event_shutting_down.set()
|
||||
@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
|
||||
version, typ, flags, stream_id, length = struct.unpack(
|
||||
YAMUX_HEADER_FORMAT, header
|
||||
)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received header for peer {self.peer_id}:"
|
||||
f"type={typ}, flags={flags}, stream_id={stream_id},"
|
||||
f"length={length}"
|
||||
@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
|
||||
0,
|
||||
)
|
||||
await self.secured_conn.write(ack_header)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Sending stream {stream_id}"
|
||||
f"to channel for peer {self.peer_id}"
|
||||
)
|
||||
@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
|
||||
elif typ == TYPE_DATA and flags & FLAG_RST:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Resetting stream {stream_id} for peer {self.peer_id}"
|
||||
)
|
||||
self.streams[stream_id].closed = True
|
||||
@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
|
||||
elif typ == TYPE_DATA and flags & FLAG_ACK:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ACK for stream"
|
||||
f"{stream_id} for peer {self.peer_id}"
|
||||
)
|
||||
elif typ == TYPE_GO_AWAY:
|
||||
error_code = length
|
||||
if error_code == GO_AWAY_NORMAL:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received GO_AWAY for peer"
|
||||
f"{self.peer_id}: Normal termination"
|
||||
)
|
||||
elif error_code == GO_AWAY_PROTOCOL_ERROR:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
|
||||
)
|
||||
elif error_code == GO_AWAY_INTERNAL_ERROR:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
|
||||
)
|
||||
else:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Received GO_AWAY for peer {self.peer_id}"
|
||||
f"with unknown error code: {error_code}"
|
||||
)
|
||||
@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
|
||||
break
|
||||
elif typ == TYPE_PING:
|
||||
if flags & FLAG_SYN:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ping request with value"
|
||||
f"{length} for peer {self.peer_id}"
|
||||
)
|
||||
@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
|
||||
)
|
||||
await self.secured_conn.write(ping_header)
|
||||
elif flags & FLAG_ACK:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received ping response with value"
|
||||
f"{length} for peer {self.peer_id}"
|
||||
)
|
||||
@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
|
||||
self.stream_buffers[stream_id].extend(data)
|
||||
self.stream_events[stream_id].set()
|
||||
if flags & FLAG_FIN:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received FIN for stream {self.peer_id}:"
|
||||
f"{stream_id}, marking recv_closed"
|
||||
)
|
||||
@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
|
||||
if self.streams[stream_id].send_closed:
|
||||
self.streams[stream_id].closed = True
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading data for stream {stream_id}: {e}")
|
||||
logger.error(f"Error reading data for stream {stream_id}: {e}")
|
||||
# Mark stream as closed on read error
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
async with stream.window_lock:
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Received window update for stream"
|
||||
f"{self.peer_id}:{stream_id},"
|
||||
f" increment: {increment}"
|
||||
@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
|
||||
and details.get("requested_count") == 2
|
||||
and details.get("received_count") == 0
|
||||
):
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Stream closed cleanly for peer {self.peer_id}"
|
||||
+ f" (IncompleteReadError: {details})"
|
||||
)
|
||||
@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
|
||||
await self._cleanup_on_error()
|
||||
break
|
||||
else:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logging.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
# Handle RawConnError with more nuance
|
||||
if isinstance(e, RawConnError):
|
||||
error_msg = str(e)
|
||||
# If RawConnError is empty, it's likely normal cleanup
|
||||
if not error_msg.strip():
|
||||
logger.info(
|
||||
f"RawConnError (empty) during cleanup for peer "
|
||||
f"{self.peer_id} (normal connection shutdown)"
|
||||
)
|
||||
else:
|
||||
# Log non-empty RawConnError as warning
|
||||
logger.warning(
|
||||
f"RawConnError during connection handling for peer "
|
||||
f"{self.peer_id}: {error_msg}"
|
||||
)
|
||||
else:
|
||||
# Log all other errors normally
|
||||
logger.error(
|
||||
f"Error in handle_incoming for peer {self.peer_id}: "
|
||||
+ f"{type(e).__name__}: {str(e)}"
|
||||
)
|
||||
# Don't crash the whole connection for temporary errors
|
||||
if self.event_shutting_down.is_set() or isinstance(
|
||||
e, (RawConnError, OSError)
|
||||
@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
|
||||
# Close the secured connection
|
||||
try:
|
||||
await self.secured_conn.close()
|
||||
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
|
||||
except Exception as close_error:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
|
||||
)
|
||||
|
||||
@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
|
||||
|
||||
# Call on_close callback if provided
|
||||
if self.on_close:
|
||||
logging.debug(f"Calling on_close for peer {self.peer_id}")
|
||||
logger.debug(f"Calling on_close for peer {self.peer_id}")
|
||||
try:
|
||||
if inspect.iscoroutinefunction(self.on_close):
|
||||
await self.on_close()
|
||||
else:
|
||||
self.on_close()
|
||||
except Exception as callback_error:
|
||||
logging.error(f"Error in on_close callback: {callback_error}")
|
||||
logger.error(f"Error in on_close callback: {callback_error}")
|
||||
|
||||
# Cancel nursery tasks
|
||||
if self._nursery:
|
||||
|
||||
@ -9,6 +9,7 @@ from libp2p.utils.varint import (
|
||||
read_varint_prefixed_bytes,
|
||||
decode_varint_from_bytes,
|
||||
decode_varint_with_size,
|
||||
read_length_prefixed_protobuf,
|
||||
)
|
||||
from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
@ -24,4 +25,5 @@ __all__ = [
|
||||
"read_varint_prefixed_bytes",
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import BinaryIO
|
||||
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.exceptions import (
|
||||
ParseError,
|
||||
)
|
||||
@ -25,42 +27,41 @@ HIGH_MASK = 2**7
|
||||
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
|
||||
|
||||
|
||||
def encode_uvarint(number: int) -> bytes:
|
||||
"""Pack `number` into varint bytes."""
|
||||
buf = b""
|
||||
while True:
|
||||
towrite = number & 0x7F
|
||||
number >>= 7
|
||||
if number:
|
||||
buf += bytes((towrite | 0x80,))
|
||||
else:
|
||||
buf += bytes((towrite,))
|
||||
def encode_uvarint(value: int) -> bytes:
|
||||
"""Encode an unsigned integer as a varint."""
|
||||
if value < 0:
|
||||
raise ValueError("Cannot encode negative value as uvarint")
|
||||
|
||||
result = bytearray()
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def decode_uvarint(data: bytes) -> int:
|
||||
"""Decode a varint from bytes."""
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
result = 0
|
||||
shift = 0
|
||||
|
||||
for byte in data:
|
||||
result |= (byte & 0x7F) << shift
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
return buf
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise ValueError("Varint too long")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def decode_varint_from_bytes(data: bytes) -> int:
|
||||
"""
|
||||
Decode a varint from bytes and return the value.
|
||||
|
||||
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
|
||||
"""
|
||||
res = 0
|
||||
for shift in itertools.count(0, 7):
|
||||
if shift > SHIFT_64_BIT_MAX:
|
||||
raise ParseError("Integer is too large...")
|
||||
|
||||
if not data:
|
||||
raise ParseError("Unexpected end of data")
|
||||
|
||||
value = data[0]
|
||||
data = data[1:]
|
||||
|
||||
res += (value & LOW_MASK) << shift
|
||||
|
||||
if not value & HIGH_MASK:
|
||||
break
|
||||
return res
|
||||
"""Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
|
||||
return decode_uvarint(data)
|
||||
|
||||
|
||||
async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
||||
|
||||
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
|
||||
"""
|
||||
Decode a varint from bytes and return (value, bytes_consumed).
|
||||
Returns (0, 0) if the data doesn't start with a valid varint.
|
||||
Decode a varint from bytes and return both the value and the number of bytes
|
||||
consumed.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (value, bytes_consumed)
|
||||
|
||||
"""
|
||||
try:
|
||||
# Calculate how many bytes the varint consumes
|
||||
varint_size = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_size += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
result = 0
|
||||
shift = 0
|
||||
bytes_consumed = 0
|
||||
|
||||
if varint_size == 0:
|
||||
return 0, 0
|
||||
for byte in data:
|
||||
result |= (byte & 0x7F) << shift
|
||||
bytes_consumed += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise ValueError("Varint too long")
|
||||
|
||||
# Extract just the varint bytes
|
||||
varint_bytes = data[:varint_size]
|
||||
|
||||
# Decode the varint
|
||||
value = decode_varint_from_bytes(varint_bytes)
|
||||
|
||||
return value, varint_size
|
||||
except Exception:
|
||||
return 0, 0
|
||||
return result, bytes_consumed
|
||||
|
||||
|
||||
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||
varint_len = encode_uvarint(len(msg_bytes))
|
||||
return varint_len + msg_bytes
|
||||
def encode_varint_prefixed(data: bytes) -> bytes:
|
||||
"""Encode data with a varint length prefix."""
|
||||
length_bytes = encode_uvarint(len(data))
|
||||
return length_bytes + data
|
||||
|
||||
|
||||
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
|
||||
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
|
||||
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
|
||||
)
|
||||
return msg_bytes[:-1]
|
||||
|
||||
|
||||
def read_varint_prefixed_bytes_sync(
|
||||
stream: BinaryIO, max_length: int = 1024 * 1024
|
||||
) -> bytes:
|
||||
"""
|
||||
Read varint-prefixed bytes from a stream.
|
||||
|
||||
Args:
|
||||
stream: A stream-like object with a read() method
|
||||
max_length: Maximum allowed data length to prevent memory exhaustion
|
||||
|
||||
Returns:
|
||||
bytes: The data without the length prefix
|
||||
|
||||
Raises:
|
||||
ValueError: If the length prefix is invalid or too large
|
||||
EOFError: If the stream ends unexpectedly
|
||||
|
||||
"""
|
||||
# Read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
byte_data = stream.read(1)
|
||||
if not byte_data:
|
||||
raise EOFError("Stream ended while reading varint length prefix")
|
||||
|
||||
length_bytes += byte_data
|
||||
if byte_data[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
# Decode the length
|
||||
length = decode_uvarint(length_bytes)
|
||||
|
||||
if length > max_length:
|
||||
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
|
||||
|
||||
# Read the data
|
||||
data = stream.read(length)
|
||||
if len(data) != length:
|
||||
raise EOFError(f"Expected {length} bytes, got {len(data)}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def read_length_prefixed_protobuf(
|
||||
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
|
||||
) -> bytes:
|
||||
"""Read a protobuf message from a stream, handling both formats."""
|
||||
if use_varint_format:
|
||||
# Read length-prefixed protobuf message from the stream
|
||||
# First read the varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
raise Exception("No length prefix received")
|
||||
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
if msg_length > max_length:
|
||||
raise Exception(
|
||||
f"Message length {msg_length} exceeds maximum allowed {max_length}"
|
||||
)
|
||||
|
||||
# Read the protobuf message
|
||||
data = await stream.read(msg_length)
|
||||
if len(data) != msg_length:
|
||||
raise Exception(
|
||||
f"Incomplete message: expected {msg_length}, got {len(data)}"
|
||||
)
|
||||
|
||||
return data
|
||||
else:
|
||||
# Read raw protobuf message from the stream
|
||||
# For raw format, read all available data in one go
|
||||
data = await stream.read()
|
||||
|
||||
# If we got no data, raise an exception
|
||||
if not data:
|
||||
raise Exception("No data received in raw format")
|
||||
|
||||
if len(data) > max_length:
|
||||
raise Exception(
|
||||
f"Message length {len(data)} exceeds maximum allowed {max_length}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
1
newsfragments/748.feature.rst
Normal file
1
newsfragments/748.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py
|
||||
2
newsfragments/755.performance.rst
Normal file
2
newsfragments/755.performance.rst
Normal file
@ -0,0 +1,2 @@
|
||||
Added throttling for async topic validators in validate_msg, enforcing a
|
||||
concurrency limit to prevent resource exhaustion under heavy load.
|
||||
1
newsfragments/766.internal.rst
Normal file
1
newsfragments/766.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Pin py-multiaddr dependency to specific git commit db8124e2321f316d3b7d2733c7df11d6ad9c03e6
|
||||
1
newsfragments/775.docs.rst
Normal file
1
newsfragments/775.docs.rst
Normal file
@ -0,0 +1 @@
|
||||
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.
|
||||
1
newsfragments/778.bugfix.rst
Normal file
1
newsfragments/778.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.
|
||||
1
newsfragments/784.bugfix.rst
Normal file
1
newsfragments/784.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Fixed incorrect handling of raw protobuf format in identify push protocol. The identify push example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.
|
||||
1
newsfragments/784.internal.rst
Normal file
1
newsfragments/784.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Yamux RawConnError Logging Refactor - Improved error handling and debug logging
|
||||
@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
|
||||
- `performance`
|
||||
- `removal`
|
||||
|
||||
So for example: `123.feature.rst`, `456.bugfix.rst`
|
||||
So for example: `1024.feature.rst`
|
||||
|
||||
**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks.
|
||||
|
||||
```
|
||||
Added support for Ed25519 key generation in libp2p peer identity creation.
|
||||
|
||||
```
|
||||
|
||||
If the PR fixes an issue, use that number here. If there is no issue,
|
||||
then open up the PR first and use the PR number for the newsfragment.
|
||||
|
||||
Note that the `towncrier` tool will automatically
|
||||
**Note** that the `towncrier` tool will automatically
|
||||
reflow your text, so don't try to do any fancy formatting. Run
|
||||
`towncrier build --draft` to get a preview of what the release notes entry
|
||||
will look like in the final release notes.
|
||||
|
||||
@ -19,7 +19,8 @@ dependencies = [
|
||||
"exceptiongroup>=1.2.0; python_version < '3.11'",
|
||||
"grpcio>=1.41.0",
|
||||
"lru-dict>=1.1.6",
|
||||
"multiaddr>=0.0.9",
|
||||
# "multiaddr>=0.0.9",
|
||||
"multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6",
|
||||
"mypy-protobuf>=3.0.0",
|
||||
"noiseprotocol>=0.3.0",
|
||||
"protobuf>=4.21.0,<5.0.0",
|
||||
|
||||
241
tests/core/identity/identify/test_identify_integration.py
Normal file
241
tests/core/identity/identify/test_identify_integration.py
Normal file
@ -0,0 +1,241 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.identity.identify.identify import (
|
||||
AGENT_VERSION,
|
||||
ID,
|
||||
PROTOCOL_VERSION,
|
||||
_multiaddr_to_bytes,
|
||||
identify_handler_for,
|
||||
parse_identify_response,
|
||||
)
|
||||
from tests.utils.factories import host_pair_factory
|
||||
|
||||
logger = logging.getLogger("libp2p.identity.identify-integration-test")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_protocol_varint_format_integration(security_protocol):
|
||||
"""Test identify protocol with varint format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Make identify request
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
assert result.listen_addrs == [
|
||||
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_protocol_raw_format_integration(security_protocol):
|
||||
"""Test identify protocol with raw format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
|
||||
# Make identify request
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
assert result.listen_addrs == [
|
||||
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_default_format_behavior(security_protocol):
|
||||
"""Test identify protocol uses correct default format."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Use default identify handler (should use varint format)
|
||||
host_a.set_stream_handler(ID, identify_handler_for(host_a))
|
||||
|
||||
# Make identify request
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
|
||||
"""Test varint dialer with raw listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Host A uses raw format
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
|
||||
# Host B makes request (will automatically detect format)
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response (should work with automatic format detection)
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
|
||||
"""Test raw dialer with varint listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Host A uses varint format
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Host B makes request (will automatically detect format)
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response (should work with automatic format detection)
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_format_detection_robustness(security_protocol):
|
||||
"""Test identify protocol format detection is robust with various message sizes."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Test both formats with different message sizes
|
||||
for use_varint in [True, False]:
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=use_varint)
|
||||
)
|
||||
|
||||
# Make identify request
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_large_message_handling(security_protocol):
|
||||
"""Test identify protocol handles large messages with many protocols."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add many protocols to make the message larger
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
for i in range(10):
|
||||
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
|
||||
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Make identify request
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read(8192)
|
||||
await stream.close()
|
||||
|
||||
# Parse response
|
||||
result = parse_identify_response(response)
|
||||
|
||||
# Verify response content
|
||||
assert result.agent_version == AGENT_VERSION
|
||||
assert result.protocol_version == PROTOCOL_VERSION
|
||||
assert result.public_key == host_a.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_message_equivalence_real_network(security_protocol):
|
||||
"""Test that both formats produce equivalent messages in real network."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Test varint format
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response_varint = await stream_varint.read(8192)
|
||||
await stream_varint.close()
|
||||
|
||||
# Test raw format
|
||||
host_a.set_stream_handler(
|
||||
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response_raw = await stream_raw.read(8192)
|
||||
await stream_raw.close()
|
||||
|
||||
# Parse both responses
|
||||
result_varint = parse_identify_response(response_varint)
|
||||
result_raw = parse_identify_response(response_raw)
|
||||
|
||||
# Both should produce identical parsed results
|
||||
assert result_varint.agent_version == result_raw.agent_version
|
||||
assert result_varint.protocol_version == result_raw.protocol_version
|
||||
assert result_varint.public_key == result_raw.public_key
|
||||
assert result_varint.listen_addrs == result_raw.listen_addrs
|
||||
@ -1,410 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.identity.identify.identify import (
|
||||
_mk_identify_protobuf,
|
||||
)
|
||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
||||
Identify,
|
||||
)
|
||||
from libp2p.io.abc import Closer, Reader, Writer
|
||||
from libp2p.utils.varint import (
|
||||
decode_varint_from_bytes,
|
||||
encode_varint_prefixed,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
|
||||
|
||||
class MockStream(Reader, Writer, Closer):
|
||||
"""Mock stream for testing identify protocol compatibility."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self.data = data
|
||||
self.position = 0
|
||||
self.closed = False
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
if self.closed or self.position >= len(self.data):
|
||||
return b""
|
||||
if n is None:
|
||||
n = len(self.data) - self.position
|
||||
result = self.data[self.position : self.position + n]
|
||||
self.position += len(result)
|
||||
return result
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
# Mock write - just store the data
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def create_identify_message(host, observed_multiaddr=None):
|
||||
"""Create an identify protobuf message."""
|
||||
return _mk_identify_protobuf(host, observed_multiaddr)
|
||||
|
||||
|
||||
def create_new_format_message(identify_msg):
|
||||
"""Create a new format (length-prefixed) identify message."""
|
||||
msg_bytes = identify_msg.SerializeToString()
|
||||
return encode_varint_prefixed(msg_bytes)
|
||||
|
||||
|
||||
def create_old_format_message(identify_msg):
|
||||
"""Create an old format (raw protobuf) identify message."""
|
||||
return identify_msg.SerializeToString()
|
||||
|
||||
|
||||
async def read_new_format_message(stream) -> bytes:
|
||||
"""Read a new format (length-prefixed) identify message."""
|
||||
# Read varint length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
break
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
if not length_bytes:
|
||||
raise ValueError("No length prefix received")
|
||||
|
||||
msg_length = decode_varint_from_bytes(length_bytes)
|
||||
|
||||
# Read the protobuf message
|
||||
response = await stream.read(msg_length)
|
||||
if len(response) != msg_length:
|
||||
raise ValueError("Incomplete message received")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_old_format_message(stream) -> bytes:
|
||||
"""Read an old format (raw protobuf) identify message."""
|
||||
# Read all available data
|
||||
response = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_compatible_message(stream) -> bytes:
|
||||
"""Read an identify message in either old or new format."""
|
||||
# Try to read a few bytes to detect the format
|
||||
first_bytes = await stream.read(10)
|
||||
if not first_bytes:
|
||||
raise ValueError("No data received")
|
||||
|
||||
# Try to decode as varint length prefix (new format)
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(first_bytes)
|
||||
|
||||
# Validate that the length is reasonable (not too large)
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate how many bytes the varint consumed
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(first_bytes):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Read the remaining protobuf message
|
||||
remaining_bytes = await stream.read(
|
||||
msg_length - (len(first_bytes) - varint_len)
|
||||
)
|
||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
||||
message_data = first_bytes[varint_len:] + remaining_bytes
|
||||
|
||||
# Try to parse as protobuf to validate
|
||||
try:
|
||||
Identify().ParseFromString(message_data)
|
||||
return message_data
|
||||
except Exception:
|
||||
# If protobuf parsing fails, fall back to old format
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to old format (raw protobuf)
|
||||
response = first_bytes
|
||||
|
||||
# Read more data if available
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def read_compatible_message_simple(stream) -> bytes:
|
||||
"""Read a message in either old or new format (simplified version for testing)."""
|
||||
# Try to read a few bytes to detect the format
|
||||
first_bytes = await stream.read(10)
|
||||
if not first_bytes:
|
||||
raise ValueError("No data received")
|
||||
|
||||
# Try to decode as varint length prefix (new format)
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(first_bytes)
|
||||
|
||||
# Validate that the length is reasonable (not too large)
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate how many bytes the varint consumed
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(first_bytes):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Read the remaining message
|
||||
remaining_bytes = await stream.read(
|
||||
msg_length - (len(first_bytes) - varint_len)
|
||||
)
|
||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
||||
return first_bytes[varint_len:] + remaining_bytes
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to old format (raw data)
|
||||
response = first_bytes
|
||||
|
||||
# Read more data if available
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def detect_format(data):
|
||||
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
|
||||
if not data:
|
||||
return "unknown"
|
||||
|
||||
# Try to decode as varint
|
||||
try:
|
||||
msg_length = decode_varint_from_bytes(data)
|
||||
|
||||
# Validate that the length is reasonable
|
||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
||||
# Calculate varint length
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(data):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
# Check if we have enough data for the message
|
||||
if len(data) >= varint_len + msg_length:
|
||||
# Additional check: try to parse the message as protobuf
|
||||
try:
|
||||
message_data = data[varint_len : varint_len + msg_length]
|
||||
Identify().ParseFromString(message_data)
|
||||
return "new"
|
||||
except Exception:
|
||||
# If protobuf parsing fails, it's probably not a valid new format
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If varint decoding fails or length is unreasonable, assume old format
|
||||
return "old"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_new_format_compatibility(security_protocol):
|
||||
"""Test that identify protocol works with new format (length-prefixed) messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create new format message
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with new format data
|
||||
stream = MockStream(new_format_data)
|
||||
|
||||
# Read using new format reader
|
||||
response = await read_new_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_old_format_compatibility(security_protocol):
|
||||
"""Test that identify protocol works with old format (raw protobuf) messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create old format message
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with old format data
|
||||
stream = MockStream(old_format_data)
|
||||
|
||||
# Read using old format reader
|
||||
response = await read_old_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_backward_compatibility_old_format(security_protocol):
|
||||
"""Test backward compatibility reader with old format messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create old format message
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with old format data
|
||||
stream = MockStream(old_format_data)
|
||||
|
||||
# Read using old format reader (which should work reliably)
|
||||
response = await read_old_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_backward_compatibility_new_format(security_protocol):
|
||||
"""Test backward compatibility reader with new format messages."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create new format message
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
|
||||
# Create mock stream with new format data
|
||||
stream = MockStream(new_format_data)
|
||||
|
||||
# Read using new format reader (which should work reliably)
|
||||
response = await read_new_format_message(stream)
|
||||
|
||||
# Parse the response
|
||||
parsed_msg = Identify()
|
||||
parsed_msg.ParseFromString(response)
|
||||
|
||||
# Verify the message content
|
||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
||||
assert parsed_msg.public_key == identify_msg.public_key
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_format_detection(security_protocol):
|
||||
"""Test that the format detection works correctly."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Test new format detection
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
format_type = detect_format(new_format_data)
|
||||
assert format_type == "new", "New format should be detected correctly"
|
||||
|
||||
# Test old format detection
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
format_type = detect_format(old_format_data)
|
||||
assert format_type == "old", "Old format should be detected correctly"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_error_handling(security_protocol):
|
||||
"""Test error handling for malformed messages."""
|
||||
from libp2p.exceptions import ParseError
|
||||
|
||||
# Test with empty data
|
||||
stream = MockStream(b"")
|
||||
with pytest.raises(ValueError, match="No data received"):
|
||||
await read_compatible_message(stream)
|
||||
|
||||
# Test with incomplete varint
|
||||
stream = MockStream(b"\x80") # Incomplete varint
|
||||
with pytest.raises(ParseError, match="Unexpected end of data"):
|
||||
await read_new_format_message(stream)
|
||||
|
||||
# Test with invalid protobuf data
|
||||
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
|
||||
with pytest.raises(Exception): # Should fail when parsing protobuf
|
||||
response = await read_new_format_message(stream)
|
||||
Identify().ParseFromString(response)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_message_equivalence(security_protocol):
|
||||
"""Test that old and new format messages are equivalent."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create identify message
|
||||
identify_msg = create_identify_message(host_a)
|
||||
|
||||
# Create both formats
|
||||
new_format_data = create_new_format_message(identify_msg)
|
||||
old_format_data = create_old_format_message(identify_msg)
|
||||
|
||||
# Extract the protobuf message from new format
|
||||
varint_len = 0
|
||||
for i, byte in enumerate(new_format_data):
|
||||
varint_len += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
|
||||
new_format_protobuf = new_format_data[varint_len:]
|
||||
|
||||
# The protobuf messages should be identical
|
||||
assert new_format_protobuf == old_format_data, (
|
||||
"Protobuf messages should be identical in both formats"
|
||||
)
|
||||
@ -0,0 +1,552 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.identity.identify_push.identify_push import (
|
||||
ID_PUSH,
|
||||
identify_push_handler_for,
|
||||
push_identify_to_peer,
|
||||
push_identify_to_peers,
|
||||
)
|
||||
from tests.utils.factories import host_pair_factory
|
||||
|
||||
logger = logging.getLogger("libp2p.identity.identify-push-integration-test")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_protocol_varint_format_integration(security_protocol):
|
||||
"""Test identify/push protocol with varint format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to host_b so it has something to push
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
# The protocols should include the dummy protocols we added
|
||||
assert len(protocols) >= 2 # Should include the dummy protocols
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_protocol_raw_format_integration(security_protocol):
|
||||
"""Test identify/push protocol with raw format in real network scenario."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 1 # Should include the dummy protocol
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_default_format_behavior(security_protocol):
|
||||
"""Test identify/push protocol uses correct default format."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Use default identify/push handler (should use varint format)
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id())
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 1 # Should include the dummy protocol
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_cross_format_compatibility_varint_to_raw(
|
||||
security_protocol,
|
||||
):
|
||||
"""Test varint pusher with raw listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Use an event to signal when handler is ready
|
||||
handler_ready = trio.Event()
|
||||
|
||||
# Create a wrapper handler that signals when ready
|
||||
original_handler = identify_push_handler_for(host_a, use_varint_format=False)
|
||||
|
||||
async def wrapped_handler(stream):
|
||||
handler_ready.set() # Signal that handler is ready
|
||||
await original_handler(stream)
|
||||
|
||||
# Host A uses raw format with wrapped handler
|
||||
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
|
||||
|
||||
# Host B pushes with varint format (should fail gracefully)
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=True
|
||||
)
|
||||
# This should fail due to format mismatch
|
||||
# Note: The format detection might be more robust than expected
|
||||
# so we just check that the operation completes
|
||||
assert isinstance(success, bool)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_cross_format_compatibility_raw_to_varint(
|
||||
security_protocol,
|
||||
):
|
||||
"""Test raw pusher with varint listener compatibility."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Use an event to signal when handler is ready
|
||||
handler_ready = trio.Event()
|
||||
|
||||
# Create a wrapper handler that signals when ready
|
||||
original_handler = identify_push_handler_for(host_a, use_varint_format=True)
|
||||
|
||||
async def wrapped_handler(stream):
|
||||
handler_ready.set() # Signal that handler is ready
|
||||
await original_handler(stream)
|
||||
|
||||
# Host A uses varint format with wrapped handler
|
||||
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
|
||||
|
||||
# Host B pushes with raw format (should fail gracefully)
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=False
|
||||
)
|
||||
# This should fail due to format mismatch
|
||||
# Note: The format detection might be more robust than expected
|
||||
# so we just check that the operation completes
|
||||
assert isinstance(success, bool)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_multiple_peers_integration(security_protocol):
|
||||
"""Test identify/push protocol with multiple peers."""
|
||||
# Create two hosts using the factory
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create a third host following the pattern from test_identify_push.py
|
||||
import multiaddr
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
# Create a new key pair for host_c
|
||||
key_pair_c = create_new_key_pair()
|
||||
host_c = new_host(key_pair=key_pair_c)
|
||||
|
||||
# Set up identify/push handlers on all hosts
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
|
||||
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
|
||||
|
||||
# Start listening on a random port using the run context manager
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
async with host_c.run([listen_addr]):
|
||||
# Connect host_c to host_a and host_b using the correct pattern
|
||||
await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0]))
|
||||
await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0]))
|
||||
|
||||
# Push identify information from host_a to all connected peers
|
||||
await push_identify_to_peers(host_a)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that host_b's peerstore has been updated with host_a's information
|
||||
peerstore_b = host_b.get_peerstore()
|
||||
peer_id_a = host_a.get_id()
|
||||
|
||||
# Check that the peer is in the peerstore
|
||||
assert peer_id_a in peerstore_b.peer_ids()
|
||||
|
||||
# Check that host_c's peerstore has been updated with host_a's information
|
||||
peerstore_c = host_c.get_peerstore()
|
||||
|
||||
# Check that the peer is in the peerstore
|
||||
assert peer_id_a in peerstore_c.peer_ids()
|
||||
|
||||
# Test for push_identify to only connected peers and not all peers
|
||||
# Disconnect a from c.
|
||||
await host_c.disconnect(host_a.get_id())
|
||||
|
||||
await push_identify_to_peers(host_c)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that host_a's peerstore has not been updated with host_c's info
|
||||
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
|
||||
# Check that host_b's peerstore has been updated with host_c's info
|
||||
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_large_message_handling(security_protocol):
|
||||
"""Test identify/push protocol handles large messages with many protocols."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add many protocols to make the message larger
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
for i in range(10):
|
||||
host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
|
||||
|
||||
# Also add some protocols to host_a to ensure it has protocols to push
|
||||
for i in range(5):
|
||||
host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
success = await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), use_varint_format=True
|
||||
)
|
||||
assert success
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated with all protocols
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) >= 10 # Should include the dummy protocols
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_peerstore_update_completeness(security_protocol):
|
||||
"""Test that identify/push updates all relevant peerstore information."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
await push_identify_to_peer(host_b, host_a.get_id())
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that protocols were added
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# Check that public key was added
|
||||
pubkey = peerstore_a.pubkey(peer_id_b)
|
||||
assert pubkey is not None
|
||||
assert pubkey.serialize() == host_b.get_public_key().serialize()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_concurrent_requests(security_protocol):
|
||||
"""Test identify/push protocol handles concurrent requests properly."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Make multiple concurrent push requests
|
||||
results = []
|
||||
|
||||
async def push_identify():
|
||||
result = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
results.append(result)
|
||||
|
||||
# Run multiple concurrent pushes using nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(push_identify)
|
||||
|
||||
# All should succeed
|
||||
assert len(results) == 3
|
||||
assert all(results)
|
||||
|
||||
# Wait a bit for the pushes to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_stream_handling(security_protocol):
|
||||
"""Test identify/push protocol properly handles stream lifecycle."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Push identify information from host_b to host_a
|
||||
success = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
assert success
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols = peerstore_a.get_protocols(peer_id_b)
|
||||
assert protocols is not None
|
||||
assert len(protocols) > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_error_handling(security_protocol):
|
||||
"""Test identify/push protocol handles errors gracefully."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Create a handler that raises an exception but catches it to prevent test
|
||||
# failure
|
||||
async def error_handler(stream):
|
||||
try:
|
||||
await stream.close()
|
||||
raise Exception("Test error")
|
||||
except Exception:
|
||||
# Catch the exception to prevent it from propagating up
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(ID_PUSH, error_handler)
|
||||
|
||||
# Push should complete (message sent) but handler should fail gracefully
|
||||
success = await push_identify_to_peer(host_b, host_a.get_id())
|
||||
assert success # The push operation itself succeeds (message sent)
|
||||
|
||||
# Wait a bit for the handler to process
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that the error was handled gracefully (no test failure)
|
||||
# The handler caught the exception and didn't propagate it
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_message_equivalence_real_network(security_protocol):
|
||||
"""Test that both formats produce equivalent peerstore updates in real network."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Test varint format
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
|
||||
)
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Get peerstore state after varint push
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
protocols_varint = peerstore_a.get_protocols(peer_id_b)
|
||||
addrs_varint = peerstore_a.addrs(peer_id_b)
|
||||
|
||||
# Clear peerstore for next test
|
||||
peerstore_a.clear_addrs(peer_id_b)
|
||||
peerstore_a.clear_protocol_data(peer_id_b)
|
||||
|
||||
# Test raw format
|
||||
host_a.set_stream_handler(
|
||||
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
|
||||
)
|
||||
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Get peerstore state after raw push
|
||||
protocols_raw = peerstore_a.get_protocols(peer_id_b)
|
||||
addrs_raw = peerstore_a.addrs(peer_id_b)
|
||||
|
||||
# Both should produce equivalent peerstore updates
|
||||
# Check that both formats successfully updated protocols
|
||||
assert protocols_varint is not None
|
||||
assert protocols_raw is not None
|
||||
assert len(protocols_varint) > 0
|
||||
assert len(protocols_raw) > 0
|
||||
|
||||
# Check that both formats successfully updated addresses
|
||||
assert addrs_varint is not None
|
||||
assert addrs_raw is not None
|
||||
assert len(addrs_varint) > 0
|
||||
assert len(addrs_raw) > 0
|
||||
|
||||
# Both should contain the same essential information
|
||||
# (exact address lists might differ due to format-specific handling)
|
||||
assert set(protocols_varint) == set(protocols_raw)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_identify_push_with_observed_address(security_protocol):
|
||||
"""Test identify/push protocol includes observed address information."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Add some protocols to both hosts
|
||||
async def dummy_handler(stream):
|
||||
pass
|
||||
|
||||
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
|
||||
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
|
||||
|
||||
# Set up identify/push handler on host_a
|
||||
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
|
||||
|
||||
# Get host_b's address as observed by host_a
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
host_b_addr = host_b.get_addrs()[0]
|
||||
observed_multiaddr = Multiaddr(str(host_b_addr))
|
||||
|
||||
# Push identify information with observed address
|
||||
await push_identify_to_peer(
|
||||
host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr
|
||||
)
|
||||
|
||||
# Wait a bit for the push to complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Verify that host_a's peerstore was updated
|
||||
peerstore_a = host_a.get_peerstore()
|
||||
peer_id_b = host_b.get_id()
|
||||
|
||||
# Check that addresses were added
|
||||
addrs = peerstore_a.addrs(peer_id_b)
|
||||
assert len(addrs) > 0
|
||||
|
||||
# The observed address should be among the stored addresses
|
||||
addr_strings = [str(addr) for addr in addrs]
|
||||
assert str(observed_multiaddr) in addr_strings
|
||||
@ -5,10 +5,12 @@ import inspect
|
||||
from typing import (
|
||||
NamedTuple,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import AsyncValidatorFn
|
||||
from libp2p.exceptions import (
|
||||
ValidationError,
|
||||
)
|
||||
@ -243,7 +245,37 @@ async def test_get_msg_validators():
|
||||
((False, True), (True, False), (True, True)),
|
||||
)
|
||||
@pytest.mark.trio
|
||||
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
||||
async def test_validate_msg_with_throttle_condition(
|
||||
is_topic_1_val_passed, is_topic_2_val_passed
|
||||
):
|
||||
CONCURRENCY_LIMIT = 10
|
||||
|
||||
state = {
|
||||
"concurrency_counter": 0,
|
||||
"max_observed": 0,
|
||||
}
|
||||
lock = trio.Lock()
|
||||
|
||||
async def mock_run_async_validator(
|
||||
self,
|
||||
func: AsyncValidatorFn,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
results: list[bool],
|
||||
) -> None:
|
||||
async with self._validator_semaphore:
|
||||
async with lock:
|
||||
state["concurrency_counter"] += 1
|
||||
if state["concurrency_counter"] > state["max_observed"]:
|
||||
state["max_observed"] = state["concurrency_counter"]
|
||||
|
||||
try:
|
||||
result = await func(msg_forwarder, msg)
|
||||
results.append(result)
|
||||
finally:
|
||||
async with lock:
|
||||
state["concurrency_counter"] -= 1
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
|
||||
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
|
||||
@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
with patch(
|
||||
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
|
||||
new=mock_run_async_validator,
|
||||
):
|
||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
|
||||
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
|
||||
f"Max concurrency observed: {state['max_observed']}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
|
||||
590
tests/core/stream_muxer/test_read_write_lock.py
Normal file
590
tests/core/stream_muxer/test_read_write_lock.py
Normal file
@ -0,0 +1,590 @@
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
from libp2p.stream_muxer.exceptions import (
|
||||
MuxedConnUnavailable,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.constants import HeaderTags
|
||||
from libp2p.stream_muxer.mplex.datastructures import StreamID
|
||||
from libp2p.stream_muxer.mplex.exceptions import (
|
||||
MplexStreamClosed,
|
||||
MplexStreamEOF,
|
||||
MplexStreamReset,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
|
||||
|
||||
|
||||
class MockMuxedConn:
|
||||
"""A mock Mplex connection for testing purposes."""
|
||||
|
||||
def __init__(self):
|
||||
self.sent_messages = []
|
||||
self.streams: dict[StreamID, MplexStream] = {}
|
||||
self.streams_lock = trio.Lock()
|
||||
self.is_unavailable = False
|
||||
|
||||
async def send_message(
|
||||
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
|
||||
) -> None:
|
||||
"""Mocks sending a message over the connection."""
|
||||
if self.is_unavailable:
|
||||
raise MuxedConnUnavailable("Connection is unavailable")
|
||||
self.sent_messages.append((flag, data, stream_id))
|
||||
# Yield to allow other tasks to run
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int]:
|
||||
"""Mocks getting the remote address."""
|
||||
return "127.0.0.1", 4001
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mplex_stream():
|
||||
"""Provides a fully initialized MplexStream and its communication channels."""
|
||||
# Use a buffered channel to prevent deadlocks in simple tests
|
||||
send_chan, recv_chan = trio.open_memory_channel(10)
|
||||
stream_id = StreamID(1, is_initiator=True)
|
||||
muxed_conn = MockMuxedConn()
|
||||
stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan)
|
||||
muxed_conn.streams[stream_id] = stream
|
||||
|
||||
yield stream, send_chan, muxed_conn
|
||||
|
||||
# Cleanup: Close channels and reset stream state
|
||||
await send_chan.aclose()
|
||||
await recv_chan.aclose()
|
||||
# Reset stream state to prevent cross-test contamination
|
||||
stream.event_local_closed = trio.Event()
|
||||
stream.event_remote_closed = trio.Event()
|
||||
stream.event_reset = trio.Event()
|
||||
|
||||
|
||||
# ===============================================
|
||||
# 1. Tests for Stream-Level Lock Integration
|
||||
# ===============================================
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_write_is_protected_by_rwlock(mplex_stream):
|
||||
"""Verify that stream.write() acquires and releases the write lock."""
|
||||
stream, _, muxed_conn = mplex_stream
|
||||
|
||||
# Mock lock methods
|
||||
original_acquire = stream.rw_lock.acquire_write
|
||||
original_release = stream.rw_lock.release_write
|
||||
|
||||
stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire)
|
||||
stream.rw_lock.release_write = MagicMock(wraps=original_release)
|
||||
|
||||
await stream.write(b"test data")
|
||||
|
||||
stream.rw_lock.acquire_write.assert_awaited_once()
|
||||
stream.rw_lock.release_write.assert_called_once()
|
||||
|
||||
# Verify the message was actually sent
|
||||
assert len(muxed_conn.sent_messages) == 1
|
||||
flag, data, stream_id = muxed_conn.sent_messages[0]
|
||||
assert flag == HeaderTags.MessageInitiator
|
||||
assert data == b"test data"
|
||||
assert stream_id == stream.stream_id
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_read_is_protected_by_rwlock(mplex_stream):
|
||||
"""Verify that stream.read() acquires and releases the read lock."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
|
||||
# Mock lock methods
|
||||
original_acquire = stream.rw_lock.acquire_read
|
||||
original_release = stream.rw_lock.release_read
|
||||
|
||||
stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire)
|
||||
stream.rw_lock.release_read = AsyncMock(wraps=original_release)
|
||||
|
||||
await send_chan.send(b"hello")
|
||||
result = await stream.read(5)
|
||||
|
||||
stream.rw_lock.acquire_read.assert_awaited_once()
|
||||
stream.rw_lock.release_read.assert_awaited_once()
|
||||
assert result == b"hello"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_readers_can_coexist(mplex_stream):
|
||||
"""Verify multiple readers can operate concurrently."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
|
||||
# Send enough data for both reads
|
||||
await send_chan.send(b"data1")
|
||||
await send_chan.send(b"data2")
|
||||
|
||||
# Track lock acquisition order
|
||||
acquisition_order = []
|
||||
release_order = []
|
||||
|
||||
# Patch lock methods to track concurrency
|
||||
original_acquire = stream.rw_lock.acquire_read
|
||||
original_release = stream.rw_lock.release_read
|
||||
|
||||
async def tracked_acquire():
|
||||
nonlocal acquisition_order
|
||||
acquisition_order.append("start")
|
||||
await original_acquire()
|
||||
acquisition_order.append("acquired")
|
||||
|
||||
async def tracked_release():
|
||||
nonlocal release_order
|
||||
release_order.append("start")
|
||||
await original_release()
|
||||
release_order.append("released")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True
|
||||
),
|
||||
patch.object(
|
||||
stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True
|
||||
),
|
||||
):
|
||||
# Execute concurrent reads
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(stream.read, 5)
|
||||
nursery.start_soon(stream.read, 5)
|
||||
|
||||
# Verify both reads happened
|
||||
assert acquisition_order.count("start") == 2
|
||||
assert acquisition_order.count("acquired") == 2
|
||||
assert release_order.count("start") == 2
|
||||
assert release_order.count("released") == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_writer_blocks_readers(mplex_stream):
|
||||
"""Verify that a writer blocks all readers and new readers queue behind."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
|
||||
writer_acquired = trio.Event()
|
||||
readers_ready = trio.Event()
|
||||
writer_finished = trio.Event()
|
||||
all_readers_started = trio.Event()
|
||||
all_readers_done = trio.Event()
|
||||
|
||||
counters = {"reader_start_count": 0, "reader_done_count": 0}
|
||||
reader_target = 3
|
||||
reader_start_lock = trio.Lock()
|
||||
|
||||
# Patch write lock to control test flow
|
||||
original_acquire_write = stream.rw_lock.acquire_write
|
||||
original_release_write = stream.rw_lock.release_write
|
||||
|
||||
async def tracked_acquire_write():
|
||||
await original_acquire_write()
|
||||
writer_acquired.set()
|
||||
# Wait for readers to queue up
|
||||
await readers_ready.wait()
|
||||
|
||||
# Must be synchronous since real release_write is sync
|
||||
def tracked_release_write():
|
||||
original_release_write()
|
||||
writer_finished.set()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
|
||||
),
|
||||
patch.object(
|
||||
stream.rw_lock, "release_write", side_effect=tracked_release_write
|
||||
),
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start writer
|
||||
nursery.start_soon(stream.write, b"test")
|
||||
await writer_acquired.wait()
|
||||
|
||||
# Start readers
|
||||
async def reader_task():
|
||||
async with reader_start_lock:
|
||||
counters["reader_start_count"] += 1
|
||||
if counters["reader_start_count"] == reader_target:
|
||||
all_readers_started.set()
|
||||
|
||||
try:
|
||||
# This will block until data is available
|
||||
await stream.read(5)
|
||||
except (MplexStreamReset, MplexStreamEOF):
|
||||
pass
|
||||
finally:
|
||||
async with reader_start_lock:
|
||||
counters["reader_done_count"] += 1
|
||||
if counters["reader_done_count"] == reader_target:
|
||||
all_readers_done.set()
|
||||
|
||||
for _ in range(reader_target):
|
||||
nursery.start_soon(reader_task)
|
||||
|
||||
# Wait until all readers are started
|
||||
await all_readers_started.wait()
|
||||
|
||||
# Let the writer finish and release the lock
|
||||
readers_ready.set()
|
||||
await writer_finished.wait()
|
||||
|
||||
# Send data to unblock the readers
|
||||
for i in range(reader_target):
|
||||
await send_chan.send(b"data" + str(i).encode())
|
||||
|
||||
# Wait for all readers to finish
|
||||
await all_readers_done.wait()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_writer_waits_for_readers(mplex_stream):
|
||||
"""Verify a writer waits for existing readers to complete."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
readers_started = trio.Event()
|
||||
writer_entered = trio.Event()
|
||||
writer_acquiring = trio.Event()
|
||||
readers_finished = trio.Event()
|
||||
|
||||
# Send data for readers
|
||||
await send_chan.send(b"data1")
|
||||
await send_chan.send(b"data2")
|
||||
|
||||
# Patch read lock to control test flow
|
||||
original_acquire_read = stream.rw_lock.acquire_read
|
||||
|
||||
async def tracked_acquire_read():
|
||||
await original_acquire_read()
|
||||
readers_started.set()
|
||||
# Wait until readers are allowed to finish
|
||||
await readers_finished.wait()
|
||||
|
||||
# Patch write lock to detect when writer is blocked
|
||||
original_acquire_write = stream.rw_lock.acquire_write
|
||||
|
||||
async def tracked_acquire_write():
|
||||
writer_acquiring.set()
|
||||
await original_acquire_write()
|
||||
writer_entered.set()
|
||||
|
||||
with (
|
||||
patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read),
|
||||
patch.object(
|
||||
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
|
||||
),
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start readers
|
||||
nursery.start_soon(stream.read, 5)
|
||||
nursery.start_soon(stream.read, 5)
|
||||
|
||||
# Wait for at least one reader to acquire the lock
|
||||
await readers_started.wait()
|
||||
|
||||
# Start writer (should block)
|
||||
nursery.start_soon(stream.write, b"test")
|
||||
|
||||
# Wait for writer to start acquiring lock
|
||||
await writer_acquiring.wait()
|
||||
|
||||
# Verify writer hasn't entered critical section
|
||||
assert not writer_entered.is_set()
|
||||
|
||||
# Allow readers to finish
|
||||
readers_finished.set()
|
||||
|
||||
# Verify writer can proceed
|
||||
await writer_entered.wait()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_lock_behavior_during_cancellation(mplex_stream):
|
||||
"""Verify that a lock is released when a task holding it is cancelled."""
|
||||
stream, _, _ = mplex_stream
|
||||
|
||||
reader_acquired_lock = trio.Event()
|
||||
|
||||
async def cancellable_reader(task_status):
|
||||
async with stream.rw_lock.read_lock():
|
||||
reader_acquired_lock.set()
|
||||
task_status.started()
|
||||
# Wait indefinitely until cancelled.
|
||||
await trio.sleep_forever()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start the reader and wait for it to acquire the lock.
|
||||
await nursery.start(cancellable_reader)
|
||||
await reader_acquired_lock.wait()
|
||||
|
||||
# Now that the reader has the lock, cancel the nursery.
|
||||
# This will cancel the reader task, and its lock should be released.
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# After the nursery is cancelled, the reader should have released the lock.
|
||||
# To verify, we try to acquire a write lock. If the read lock was not
|
||||
# released, this will time out.
|
||||
with trio.move_on_after(1) as cancel_scope:
|
||||
async with stream.rw_lock.write_lock():
|
||||
pass
|
||||
if cancel_scope.cancelled_caught:
|
||||
pytest.fail(
|
||||
"Write lock could not be acquired after a cancelled reader, "
|
||||
"indicating the read lock was not released."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_concurrent_read_write_sequence(mplex_stream):
|
||||
"""Verify complex sequence of interleaved reads and writes."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
results = []
|
||||
# Use a mock to intercept writes and feed them back to the read channel
|
||||
original_write = stream.write
|
||||
|
||||
reader1_finished = trio.Event()
|
||||
writer1_finished = trio.Event()
|
||||
reader2_finished = trio.Event()
|
||||
|
||||
async def mocked_write(data: bytes) -> None:
|
||||
await original_write(data)
|
||||
# Simulate the other side receiving the data and sending a response
|
||||
# by putting data into the read channel.
|
||||
await send_chan.send(data)
|
||||
|
||||
with patch.object(stream, "write", wraps=mocked_write) as patched_write:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Test scenario:
|
||||
# 1. Reader 1 starts, waits for data.
|
||||
# 2. Writer 1 writes, which gets fed back to the stream.
|
||||
# 3. Reader 2 starts, reads what Writer 1 wrote.
|
||||
# 4. Writer 2 writes.
|
||||
|
||||
async def reader1():
|
||||
nonlocal results
|
||||
results.append("R1 start")
|
||||
data = await stream.read(5)
|
||||
results.append(data)
|
||||
results.append("R1 done")
|
||||
reader1_finished.set()
|
||||
|
||||
async def writer1():
|
||||
nonlocal results
|
||||
await reader1_finished.wait()
|
||||
results.append("W1 start")
|
||||
await stream.write(b"write1")
|
||||
results.append("W1 done")
|
||||
writer1_finished.set()
|
||||
|
||||
async def reader2():
|
||||
nonlocal results
|
||||
await writer1_finished.wait()
|
||||
# This will read the data from writer1
|
||||
results.append("R2 start")
|
||||
data = await stream.read(6)
|
||||
results.append(data)
|
||||
results.append("R2 done")
|
||||
reader2_finished.set()
|
||||
|
||||
async def writer2():
|
||||
nonlocal results
|
||||
await reader2_finished.wait()
|
||||
results.append("W2 start")
|
||||
await stream.write(b"write2")
|
||||
results.append("W2 done")
|
||||
|
||||
# Execute sequence
|
||||
nursery.start_soon(reader1)
|
||||
nursery.start_soon(writer1)
|
||||
nursery.start_soon(reader2)
|
||||
nursery.start_soon(writer2)
|
||||
|
||||
await send_chan.send(b"data1")
|
||||
|
||||
# Verify sequence and that write was called
|
||||
assert patched_write.call_count == 2
|
||||
assert results == [
|
||||
"R1 start",
|
||||
b"data1",
|
||||
"R1 done",
|
||||
"W1 start",
|
||||
"W1 done",
|
||||
"R2 start",
|
||||
b"write1",
|
||||
"R2 done",
|
||||
"W2 start",
|
||||
"W2 done",
|
||||
]
|
||||
|
||||
|
||||
# ===============================================
|
||||
# 2. Tests for Reset, EOF, and Close Interactions
|
||||
# ===============================================
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_read_after_remote_close_triggers_eof(mplex_stream):
|
||||
"""Verify reading from a remotely closed stream returns EOF correctly."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
|
||||
# Send some data that can be read first
|
||||
await send_chan.send(b"data")
|
||||
# Close the channel to signify no more data will ever arrive
|
||||
await send_chan.aclose()
|
||||
|
||||
# Mark the stream as remotely closed
|
||||
stream.event_remote_closed.set()
|
||||
|
||||
# The first read should succeed, consuming the buffered data
|
||||
data = await stream.read(4)
|
||||
assert data == b"data"
|
||||
|
||||
# Now that the buffer is empty and the channel is closed, this should raise EOF
|
||||
with pytest.raises(MplexStreamEOF):
|
||||
await stream.read(1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_read_on_closed_stream_raises_eof(mplex_stream):
|
||||
"""Test that reading from a closed stream with no data raises EOF."""
|
||||
stream, send_chan, _ = mplex_stream
|
||||
stream.event_remote_closed.set()
|
||||
await send_chan.aclose() # Ensure the channel is closed
|
||||
|
||||
# Reading from a stream that is closed and has no data should raise EOF
|
||||
with pytest.raises(MplexStreamEOF):
|
||||
await stream.read(100)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_write_to_locally_closed_stream_raises(mplex_stream):
|
||||
"""Verify writing to a locally closed stream raises MplexStreamClosed."""
|
||||
stream, _, _ = mplex_stream
|
||||
stream.event_local_closed.set()
|
||||
|
||||
with pytest.raises(MplexStreamClosed):
|
||||
await stream.write(b"this should fail")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_read_from_reset_stream_raises(mplex_stream):
|
||||
"""Verify reading from a reset stream raises MplexStreamReset."""
|
||||
stream, _, _ = mplex_stream
|
||||
stream.event_reset.set()
|
||||
|
||||
with pytest.raises(MplexStreamReset):
|
||||
await stream.read(10)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_write_to_reset_stream_raises(mplex_stream):
|
||||
"""Verify writing to a reset stream raises MplexStreamClosed."""
|
||||
stream, _, _ = mplex_stream
|
||||
# A stream reset implies it's also locally closed.
|
||||
await stream.reset()
|
||||
|
||||
# The `write` method checks `event_local_closed`, which `reset` sets.
|
||||
with pytest.raises(MplexStreamClosed):
|
||||
await stream.write(b"this should also fail")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_reset_cleans_up_resources(mplex_stream):
|
||||
"""Verify reset() cleans up stream state and resources."""
|
||||
stream, _, muxed_conn = mplex_stream
|
||||
stream_id = stream.stream_id
|
||||
|
||||
assert stream_id in muxed_conn.streams
|
||||
await stream.reset()
|
||||
|
||||
assert stream.event_reset.is_set()
|
||||
assert stream.event_local_closed.is_set()
|
||||
assert stream.event_remote_closed.is_set()
|
||||
assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages
|
||||
assert stream_id not in muxed_conn.streams
|
||||
# Verify the underlying data channel is closed
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await stream.incoming_data_channel.receive()
|
||||
|
||||
|
||||
# ===============================================
|
||||
# 3. Rigorous Concurrency Tests with Events
|
||||
# ===============================================
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_writer_is_blocked_by_reader_using_events(mplex_stream):
|
||||
"""Verify a writer must wait for a reader using trio.Event for synchronization."""
|
||||
stream, _, _ = mplex_stream
|
||||
|
||||
reader_has_lock = trio.Event()
|
||||
writer_finished = trio.Event()
|
||||
|
||||
async def reader():
|
||||
async with stream.rw_lock.read_lock():
|
||||
reader_has_lock.set()
|
||||
# Hold the lock until the writer has finished its attempt
|
||||
await writer_finished.wait()
|
||||
|
||||
async def writer():
|
||||
await reader_has_lock.wait()
|
||||
# This call will now block until the reader releases the lock
|
||||
await stream.write(b"data")
|
||||
writer_finished.set()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(reader)
|
||||
nursery.start_soon(writer)
|
||||
|
||||
# Verify writer is blocked
|
||||
await wait_all_tasks_blocked()
|
||||
assert not writer_finished.is_set()
|
||||
|
||||
# Signal the reader to finish
|
||||
writer_finished.set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream):
|
||||
"""Verify that multiple readers can acquire a read lock simultaneously."""
|
||||
stream, _, _ = mplex_stream
|
||||
|
||||
counters = {"readers_in_critical_section": 0}
|
||||
lock = trio.Lock() # To safely mutate the counter
|
||||
|
||||
reader1_acquired = trio.Event()
|
||||
reader2_acquired = trio.Event()
|
||||
all_readers_finished = trio.Event()
|
||||
|
||||
async def concurrent_reader(event_to_set: trio.Event):
|
||||
async with stream.rw_lock.read_lock():
|
||||
async with lock:
|
||||
counters["readers_in_critical_section"] += 1
|
||||
event_to_set.set()
|
||||
# Wait until all readers have finished before exiting the lock context
|
||||
await all_readers_finished.wait()
|
||||
async with lock:
|
||||
counters["readers_in_critical_section"] -= 1
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(concurrent_reader, reader1_acquired)
|
||||
nursery.start_soon(concurrent_reader, reader2_acquired)
|
||||
|
||||
# Wait for both readers to acquire their locks
|
||||
await reader1_acquired.wait()
|
||||
await reader2_acquired.wait()
|
||||
|
||||
# Check that both were in the critical section at the same time
|
||||
async with lock:
|
||||
assert counters["readers_in_critical_section"] == 2
|
||||
|
||||
# Signal for all readers to finish
|
||||
all_readers_finished.set()
|
||||
|
||||
# Verify they exit cleanly
|
||||
await wait_all_tasks_blocked()
|
||||
async with lock:
|
||||
assert counters["readers_in_critical_section"] == 0
|
||||
Reference in New Issue
Block a user