12 Commits

71 changed files with 1254 additions and 4404 deletions

View File

@@ -1,13 +0,0 @@
libp2p.discovery.bootstrap package
==================================
Submodules
----------
Module contents
---------------
.. automodule:: libp2p.discovery.bootstrap
:members:
:undoc-members:
:show-inheritance:

View File

@@ -7,7 +7,6 @@ Subpackages
.. toctree::
:maxdepth: 4
libp2p.discovery.bootstrap
libp2p.discovery.events
libp2p.discovery.mdns

View File

@@ -1,136 +0,0 @@
import argparse
import logging
import secrets
import multiaddr
import trio
from libp2p import new_host
from libp2p.abc import PeerInfo
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.discovery.events.peerDiscovery import peerDiscovery
# Configure logging
logger = logging.getLogger("libp2p.discovery.bootstrap")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)
# Configure root logger to only show warnings and above to reduce noise
# This prevents verbose DEBUG messages from multiaddr, DNS, etc.
logging.getLogger().setLevel(logging.WARNING)
# Specifically silence noisy libraries
logging.getLogger("multiaddr").setLevel(logging.WARNING)
logging.getLogger("root").setLevel(logging.WARNING)
def on_peer_discovery(peer_info: PeerInfo) -> None:
"""Handler for peer discovery events."""
logger.info(f"🔍 Discovered peer: {peer_info.peer_id}")
logger.debug(f" Addresses: {[str(addr) for addr in peer_info.addrs]}")
# Example bootstrap peers
BOOTSTRAP_PEERS = [
"/dnsaddr/github.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/cloudflare.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/google.com/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/bootstrap.libp2p.io/p2p/QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",
"/dnsaddr/bootstrap.libp2p.io/p2p/QmbLHAnMoJPWSCR5Zhtx6BHJX9KiKNN6tpvbUcqanj75Nb",
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
"/ip4/128.199.219.111/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
"/ip4/104.236.76.40/tcp/4001/p2p/QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64",
"/ip4/178.62.158.247/tcp/4001/p2p/QmSoLer265NRgSp2LA3dPaeykiS1J6DifTC88f5uVQKNAd",
"/ip6/2604:a880:1:20::203:d001/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
"/ip6/2400:6180:0:d0::151:6001/tcp/4001/p2p/QmSoLSafTMBsPKadTEgaXctDQVcqN88CNLHXMkTNwMKPnu",
"/ip6/2a03:b0c0:0:1010::23:1001/tcp/4001/p2p/QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm",
]
async def run(port: int, bootstrap_addrs: list[str]) -> None:
"""Run the bootstrap discovery example."""
# Generate key pair
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create listen address
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
# Register peer discovery handler
peerDiscovery.register_peer_discovered_handler(on_peer_discovery)
logger.info("🚀 Starting Bootstrap Discovery Example")
logger.info(f"📍 Listening on: {listen_addr}")
logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}")
print("\n" + "=" * 60)
print("Bootstrap Discovery Example")
print("=" * 60)
print("This example demonstrates connecting to bootstrap peers.")
print("Watch the logs for peer discovery events!")
print("Press Ctrl+C to exit.")
print("=" * 60)
# Create and run host with bootstrap discovery
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
try:
async with host.run(listen_addrs=[listen_addr]):
# Keep running and log peer discovery events
await trio.sleep_forever()
except KeyboardInterrupt:
logger.info("👋 Shutting down...")
def main() -> None:
"""Main entry point."""
description = """
Bootstrap Discovery Example for py-libp2p
This example demonstrates how to use bootstrap peers for peer discovery.
Bootstrap peers are predefined peers that help new nodes join the network.
Usage:
python bootstrap.py -p 8000
python bootstrap.py -p 8001 --custom-bootstrap \\
"/ip4/127.0.0.1/tcp/8000/p2p/QmYourPeerID"
"""
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
"-p", "--port", default=0, type=int, help="Port to listen on (default: random)"
)
parser.add_argument(
"--custom-bootstrap",
nargs="*",
help="Custom bootstrap addresses (space-separated)",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="Enable verbose output"
)
args = parser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
# Use custom bootstrap addresses if provided, otherwise use defaults
bootstrap_addrs = (
args.custom_bootstrap if args.custom_bootstrap else BOOTSTRAP_PEERS
)
try:
trio.run(run, args.port, bootstrap_addrs)
except KeyboardInterrupt:
logger.info("Exiting...")
if __name__ == "__main__":
main()
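For reference, a minimal sketch of what the bootstrap machinery does with each address string above: parse it as a multiaddr and split off the trailing /p2p/ component into a PeerInfo. It uses only helpers that appear elsewhere in this diff (multiaddr.Multiaddr, info_from_p2p_addr); the function name is illustrative.

import multiaddr
from libp2p.peer.peerinfo import info_from_p2p_addr

def peer_info_from_bootstrap(addr_str: str):
    # Parse the string into a multiaddr; raises on malformed input.
    maddr = multiaddr.Multiaddr(addr_str)
    # Extract (peer_id, addrs) from the trailing /p2p/<peer-id> component.
    return info_from_p2p_addr(maddr)

info = peer_info_from_bootstrap(BOOTSTRAP_PEERS[5])  # the /ip4/104.131... entry
print(info.peer_id, [str(a) for a in info.addrs])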

View File

@@ -43,9 +43,6 @@ async def run(port: int, destination: str) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host()
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
if not destination: # it's the server
async def stream_handler(stream: INetStream) -> None:

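This hunk (and several below) removes the peerstore cleanup task that used to be started alongside the host. A sketch of the pre-PR pattern, for readers pinned to the old API; start_cleanup_task is the IPeerStore method deleted later in this diff:

import multiaddr
import trio
from libp2p import new_host

async def serve_with_cleanup(port: int) -> None:
    host = new_host()
    listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
    async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
        # Pre-PR API: purge expired peer records/addresses every 60 seconds.
        nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
        await trio.sleep_forever()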
View File

@@ -45,10 +45,7 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
secret = secrets.token_bytes(32)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
async with host.run(listen_addrs=[listen_addr]):
print(f"I am {host.get_id().to_string()}")
if not destination: # it's the server

View File

@@ -1,7 +1,6 @@
import argparse
import base64
import logging
import sys
import multiaddr
import trio
@@ -14,8 +13,6 @@ from libp2p.identity.identify.identify import (
identify_handler_for,
parse_identify_response,
)
from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.peer.envelope import debug_dump_envelope, unmarshal_envelope
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
)
@@ -34,11 +31,10 @@ def decode_multiaddrs(raw_addrs):
return decoded_addrs
def print_identify_response(identify_response: Identify):
def print_identify_response(identify_response):
"""Pretty-print Identify response."""
public_key_b64 = base64.b64encode(identify_response.public_key).decode("utf-8")
listen_addrs = decode_multiaddrs(identify_response.listen_addrs)
signed_peer_record = unmarshal_envelope(identify_response.signedPeerRecord)
try:
observed_addr_decoded = decode_multiaddrs([identify_response.observed_addr])
except Exception:
@@ -54,8 +50,6 @@ def print_identify_response(identify_response: Identify):
f" Agent Version: {identify_response.agent_version}"
)
debug_dump_envelope(signed_peer_record)
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
localhost_ip = "0.0.0.0"
@@ -66,158 +60,58 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
host_a = new_host()
# Set up identify handler with specified format
# Set use_varint_format = False if you want to inspect the signed peer record
identify_handler = identify_handler_for(
host_a, use_varint_format=use_varint_format
)
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
async with (
host_a.run(listen_addrs=[listen_addr]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60)
async with host_a.run(listen_addrs=[listen_addr]):
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
# connections
server_addr = str(host_a.get_addrs()[0])
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
format_flag = "--raw-format" if not use_varint_format else ""
print(
f"First host listening (using {format_name} format). "
f"Run this from another console:\n\n"
f"identify-demo {format_flag} -d {client_addr}\n"
f"identify-demo "
f"-d {client_addr}\n"
)
print("Waiting for incoming identify request...")
# Add a custom handler to show connection events
async def custom_identify_handler(stream):
peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 Received identify request from peer: {peer_id}")
# Show remote address in multiaddr format
try:
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(
remote_address
)
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
else:
print(f" Remote address: {remote_address}")
except Exception:
print(f" Remote address: {stream.get_remote_address()}")
# Call the original handler
await identify_handler(stream)
print(f"✅ Successfully processed identify request from {peer_id}")
# Replace the handler with our custom one
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
try:
await trio.sleep_forever()
except KeyboardInterrupt:
print("\n🛑 Shutting down listener...")
logger.info("Listener interrupted by user")
return
else:
# Create second host (dialer)
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
host_b = new_host()
async with (
host_b.run(listen_addrs=[listen_addr]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_b.get_peerstore().start_cleanup_task, 60)
async with host_b.run(listen_addrs=[listen_addr]):
# Connect to the first host
print(f"dialer (host_b) listening on {host_b.get_addrs()[0]}")
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
print(f"Second host connecting to peer: {info.peer_id}")
try:
await host_b.connect(info)
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print(
"\n💡 Make sure the peer is running and the address is correct."
)
return
else:
# Re-raise other exceptions
raise
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
try:
print("Starting identify protocol...")
# Read the response using the utility function
from libp2p.utils.varint import read_length_prefixed_protobuf
response = await read_length_prefixed_protobuf(
stream, use_varint_format
)
full_response = response
# Read the complete response (could be either format)
# Read a larger chunk to get all the data before stream closes
response = await stream.read(8192) # Read enough data in one go
await stream.close()
# Parse the response using the robust protocol-level function
# This handles both old and new formats automatically
identify_msg = parse_identify_response(full_response)
identify_msg = parse_identify_response(response)
print_identify_response(identify_msg)
except Exception as e:
error_msg = str(e)
print(f"Identify protocol error: {error_msg}")
# Check for specific format mismatch errors
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener"
)
print("is using raw protobuf format.")
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(f"identify-demo --raw-format -d {destination}")
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-demo -d {destination}")
print("=" * 60)
else:
import traceback
traceback.print_exc()
print(f"Identify protocol error: {e}")
return
@@ -253,27 +147,16 @@ def main() -> None:
"length-prefixed (new format)"
),
)
args = parser.parse_args()
# Determine format: use varint (length-prefixed) if --raw-format is specified,
# otherwise use raw protobuf format (old format)
use_varint_format = args.raw_format
# Determine format: raw format if --raw-format is specified, otherwise
# length-prefixed
use_varint_format = not args.raw_format
try:
if args.destination:
# Run in dialer mode
trio.run(run, *(args.port, args.destination, use_varint_format))
else:
# Run in listener mode
trio.run(run, *(args.port, args.destination, use_varint_format))
except KeyboardInterrupt:
print("\n👋 Goodbye!")
logger.info("Application interrupted by user")
except Exception as e:
print(f"\n❌ Error: {str(e)}")
logger.error("Error: %s", str(e))
sys.exit(1)
pass
if __name__ == "__main__":

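The format-mismatch handling above comes down to how the length-prefixed (varint) framing works: a message is an unsigned LEB128 length followed by the raw protobuf bytes. A self-contained sketch of that framing (py-libp2p's own helper is read_length_prefixed_protobuf, used above):

def encode_varint(n: int) -> bytes:
    # Unsigned LEB128: 7 payload bits per byte, MSB set on all but the last.
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        out.append(byte | (0x80 if n else 0))
        if not n:
            return bytes(out)

def decode_varint(buf: bytes) -> tuple[int, int]:
    # Returns (value, number of bytes consumed).
    value = shift = 0
    for i, byte in enumerate(buf):
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, i + 1
        shift += 7
    raise ValueError("truncated varint")

payload = b"identify protobuf bytes"
framed = encode_varint(len(payload)) + payload  # length prefix + message
length, consumed = decode_varint(framed)
assert framed[consumed:consumed + length] == payload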
View File

@@ -11,26 +11,23 @@ This example shows how to:
import logging
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.abc import (
INetStream,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
from libp2p.identity.identify import (
identify_handler_for,
)
from libp2p.identity.identify_push import (
ID_PUSH,
identify_push_handler_for,
push_identify_to_peer,
)
from libp2p.peer.peerinfo import (
@@ -41,145 +38,8 @@ from libp2p.peer.peerinfo import (
logger = logging.getLogger(__name__)
def create_custom_identify_handler(host, host_name: str):
"""Create a custom identify handler that displays received information."""
async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
print(f"\n🔍 {host_name} received identify request from peer: {peer_id}")
# Get the standard identify response using the existing function
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
_remote_address_to_multiaddr,
)
# Get observed address
observed_multiaddr = None
try:
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
except Exception:
pass
# Build the identify protobuf
identify_msg = _mk_identify_protobuf(host, observed_multiaddr)
response_data = identify_msg.SerializeToString()
print(f" 📋 {host_name} identify information:")
if identify_msg.HasField("protocol_version"):
print(f" Protocol Version: {identify_msg.protocol_version}")
if identify_msg.HasField("agent_version"):
print(f" Agent Version: {identify_msg.agent_version}")
if identify_msg.HasField("public_key"):
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
if identify_msg.listen_addrs:
print(" Listen Addresses:")
for addr_bytes in identify_msg.listen_addrs:
addr = multiaddr.Multiaddr(addr_bytes)
print(f" - {addr}")
if identify_msg.protocols:
print(" Supported Protocols:")
for protocol in identify_msg.protocols:
print(f" - {protocol}")
# Send the response
await stream.write(response_data)
await stream.close()
return handle_identify
def create_custom_identify_push_handler(host, host_name: str):
"""Create a custom identify/push handler that displays received information."""
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
print(f"\n📤 {host_name} received identify/push from peer: {peer_id}")
try:
# Read the identify message using the utility function
from libp2p.utils.varint import read_length_prefixed_protobuf
data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
# Parse the identify message
identify_msg = Identify()
identify_msg.ParseFromString(data)
print(" 📋 Received identify information:")
if identify_msg.HasField("protocol_version"):
print(f" Protocol Version: {identify_msg.protocol_version}")
if identify_msg.HasField("agent_version"):
print(f" Agent Version: {identify_msg.agent_version}")
if identify_msg.HasField("public_key"):
print(f" Public Key: {identify_msg.public_key.hex()[:16]}...")
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr)
print(f" Observed Address: {observed_addr}")
if identify_msg.listen_addrs:
print(" Listen Addresses:")
for addr_bytes in identify_msg.listen_addrs:
addr = multiaddr.Multiaddr(addr_bytes)
print(f" - {addr}")
if identify_msg.protocols:
print(" Supported Protocols:")
for protocol in identify_msg.protocols:
print(f" - {protocol}")
# Update the peerstore with the new information
from libp2p.identity.identify_push.identify_push import (
_update_peerstore_from_identify,
)
await _update_peerstore_from_identify(
host.get_peerstore(), peer_id, identify_msg
)
print(f"{host_name} updated peerstore with new information")
except Exception as e:
print(f" ❌ Error processing identify/push: {e}")
finally:
await stream.close()
return handle_identify_push
async def display_peerstore_info(host, host_name: str, peer_id, description: str):
"""Display peerstore information for a specific peer."""
peerstore = host.get_peerstore()
try:
addrs = peerstore.addrs(peer_id)
except Exception:
addrs = []
try:
protocols = peerstore.get_protocols(peer_id)
except Exception:
protocols = []
print(f"\n📚 {host_name} peerstore for {description}:")
print(f" Peer ID: {peer_id}")
if addrs:
print(" Addresses:")
for addr in addrs:
print(f" - {addr}")
else:
print(" Addresses: None")
if protocols:
print(" Protocols:")
for protocol in protocols:
print(f" - {protocol}")
else:
print(" Protocols: None")
async def main() -> None:
print("\n==== Starting Enhanced Identify-Push Example ====\n")
print("\n==== Starting Identify-Push Example ====\n")
# Create key pairs for the two hosts
key_pair_1 = create_new_key_pair()
@@ -188,57 +48,45 @@ async def main() -> None:
# Create the first host
host_1 = new_host(key_pair=key_pair_1)
# Set up custom identify and identify/push handlers
host_1.set_stream_handler(
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1")
)
host_1.set_stream_handler(
ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1")
)
# Set up the identify and identify/push handlers
host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1))
host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1))
# Create the second host
host_2 = new_host(key_pair=key_pair_2)
# Set up custom identify and identify/push handlers
host_2.set_stream_handler(
TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2")
)
host_2.set_stream_handler(
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
)
# Set up the identify and identify/push handlers
host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2))
host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2))
# Start listening on random ports using the run context manager
import multiaddr
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
async with (
host_1.run([listen_addr_1]),
host_2.run([listen_addr_2]),
trio.open_nursery() as nursery,
):
# Start the peer-store cleanup task
nursery.start_soon(host_1.get_peerstore().start_cleanup_task, 60)
nursery.start_soon(host_2.get_peerstore().start_cleanup_task, 60)
async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]):
# Get the addresses of both hosts
addr_1 = host_1.get_addrs()[0]
logger.info(f"Host 1 listening on {addr_1}")
print(f"Host 1 listening on {addr_1}")
print(f"Peer ID: {host_1.get_id().pretty()}")
addr_2 = host_2.get_addrs()[0]
logger.info(f"Host 2 listening on {addr_2}")
print(f"Host 2 listening on {addr_2}")
print(f"Peer ID: {host_2.get_id().pretty()}")
print("🏠 Host Configuration:")
print(f" Host 1: {addr_1}")
print(f" Host 1 Peer ID: {host_1.get_id().pretty()}")
print(f" Host 2: {addr_2}")
print(f" Host 2 Peer ID: {host_2.get_id().pretty()}")
print("\n🔗 Connecting Host 2 to Host 1...")
print("\nConnecting Host 2 to Host 1...")
# Connect host_2 to host_1
peer_info = info_from_p2p_addr(addr_1)
await host_2.connect(peer_info)
print("Host 2 successfully connected to Host 1")
logger.info("Host 2 connected to Host 1")
print("Host 2 successfully connected to Host 1")
# Run the identify protocol from host_2 to host_1
print("\n🔄 Running identify protocol (Host 2 → Host 1)...")
# (so Host 1 learns Host 2's address)
from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID
stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,))
@@ -246,58 +94,64 @@ async def main() -> None:
await stream.close()
# Run the identify protocol from host_1 to host_2
print("\n🔄 Running identify protocol (Host 1 → Host 2)...")
# (so Host 2 learns Host 1's address)
stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,))
response = await stream.read()
await stream.close()
# Update Host 1's peerstore with Host 2's addresses
# --- NEW CODE: Update Host 1's peerstore with Host 2's addresses ---
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
identify_msg = Identify()
identify_msg.ParseFromString(response)
peerstore_1 = host_1.get_peerstore()
peer_id_2 = host_2.get_id()
for addr_bytes in identify_msg.listen_addrs:
maddr = multiaddr.Multiaddr(addr_bytes)
peerstore_1.add_addr(peer_id_2, maddr, ttl=3600)
# TTL can be any positive int
peerstore_1.add_addr(
peer_id_2,
maddr,
ttl=3600,
)
# --- END NEW CODE ---
# Display peerstore information before push
await display_peerstore_info(
host_1, "Host 1", peer_id_2, "Host 2 (before push)"
# Now Host 1's peerstore should have Host 2's address
peerstore_1 = host_1.get_peerstore()
peer_id_2 = host_2.get_id()
addrs_1_for_2 = peerstore_1.addrs(peer_id_2)
logger.info(
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
f"{addrs_1_for_2}"
)
print(
f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: "
f"{addrs_1_for_2}"
)
# Push identify information from host_1 to host_2
print("\n📤 Host 1 pushing identify information to Host 2...")
logger.info("Host 1 pushing identify information to Host 2")
print("\nHost 1 pushing identify information to Host 2...")
try:
# Call push_identify_to_peer which now returns a boolean
success = await push_identify_to_peer(host_1, host_2.get_id())
if success:
print("Identify push completed successfully!")
logger.info("Identify push completed successfully")
print("Identify push completed successfully!")
else:
print("⚠️ Identify push didn't complete successfully")
logger.warning("Identify push didn't complete successfully")
print("\nWarning: Identify push didn't complete successfully")
except Exception as e:
print(f"Error during identify push: {str(e)}")
logger.error(f"Error during identify push: {str(e)}")
print(f"\nError during identify push: {str(e)}")
# Give a moment for the identify/push processing to complete
await trio.sleep(0.5)
# Display peerstore information after push
await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)")
await display_peerstore_info(
host_2, "Host 2", host_1.get_id(), "Host 1 (after push)"
)
# Give more time for background tasks to finish and connections to stabilize
print("\n⏳ Waiting for background tasks to complete...")
await trio.sleep(1.0)
# Gracefully close connections to prevent connection errors
print("🔌 Closing connections...")
await host_2.disconnect(host_1.get_id())
await trio.sleep(0.2)
print("\n🎉 Example completed successfully!")
# Add this at the end of your async with block:
await trio.sleep(0.5) # Give background tasks time to finish
if __name__ == "__main__":

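The address-learning step above (parse the identify response, then add_addr each advertised listen address) can be factored into a small helper. A sketch against the same Identify protobuf and peerstore API used in this example; the function name is illustrative:

import multiaddr
from libp2p.identity.identify.pb.identify_pb2 import Identify

def learn_addrs_from_identify(peerstore, peer_id, response: bytes, ttl: int = 3600) -> None:
    # Parse the raw identify payload and record every advertised listen
    # address under the given TTL in seconds, mirroring the inline loop above.
    identify_msg = Identify()
    identify_msg.ParseFromString(response)
    for addr_bytes in identify_msg.listen_addrs:
        peerstore.add_addr(peer_id, multiaddr.Multiaddr(addr_bytes), ttl=ttl)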
View File

@@ -41,9 +41,6 @@ from libp2p.identity.identify import (
ID as ID_IDENTIFY,
identify_handler_for,
)
from libp2p.identity.identify.identify import (
_remote_address_to_multiaddr,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
@@ -75,30 +72,40 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
async def handle_identify_push(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
# Get remote address information
try:
remote_address = stream.get_remote_address()
if remote_address:
observed_multiaddr = _remote_address_to_multiaddr(remote_address)
logger.info(
"Connection from remote peer %s, address: %s, multiaddr: %s",
peer_id,
remote_address,
observed_multiaddr,
)
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
# Add the peer ID to create a complete multiaddr
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
print(f" Remote address: {complete_multiaddr}")
except Exception as e:
logger.error("Error getting remote address: %s", e)
print(f"\n🔗 Received identify/push request from peer: {peer_id}")
if use_varint_format:
# Read length-prefixed identify message from the stream
from libp2p.utils.varint import decode_varint_from_bytes
try:
# Use the utility function to read the protobuf message
from libp2p.utils.varint import read_length_prefixed_protobuf
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
data = await read_length_prefixed_protobuf(stream, use_varint_format)
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
data = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
data += chunk
identify_msg = Identify()
identify_msg.ParseFromString(data)
@@ -148,41 +155,11 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
await _update_peerstore_from_identify(peerstore, peer_id, identify_msg)
logger.info("Successfully processed identify/push from peer %s", peer_id)
print(f"Successfully processed identify/push from peer {peer_id}")
print(f"\nSuccessfully processed identify/push from peer {peer_id}")
except Exception as e:
error_msg = str(e)
logger.error(
"Error processing identify/push from %s: %s", peer_id, error_msg
)
print(f"\nError processing identify/push from {peer_id}: {error_msg}")
# Check for specific format mismatch errors
if (
"Error parsing message" in error_msg
or "DecodeError" in error_msg
or "ParseFromString" in error_msg
):
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"dialer is using raw protobuf format."
)
print("\nTo fix this, run the dialer with the --raw-format flag:")
print(
"identify-push-listener-dialer-demo --raw-format -d <ADDRESS>"
)
else:
print("You are using raw protobuf format but the dialer")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format flag:"
)
print("identify-push-listener-dialer-demo -d <ADDRESS>")
print("=" * 60)
logger.error("Error processing identify/push from %s: %s", peer_id, e)
print(f"\nError processing identify/push from {peer_id}: {e}")
finally:
# Close the stream after processing
await stream.close()
@@ -190,9 +167,7 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True):
return handle_identify_push
async def run_listener(
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
) -> None:
async def run_listener(port: int, use_varint_format: bool = True) -> None:
"""Run a host in listener mode."""
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
print(
@@ -212,13 +187,12 @@ async def run_listener(
)
host.set_stream_handler(
ID_IDENTIFY_PUSH,
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
identify_push_handler_for(host, use_varint_format=use_varint_format),
)
# Start listening
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
try:
async with host.run([listen_addr]):
addr = host.get_addrs()[0]
logger.info("Listener host ready!")
@@ -231,22 +205,11 @@ async def run_listener(
print(f"Peer ID: {host.get_id().pretty()}")
print("\nRun dialer with command:")
if raw_format_flag:
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
else:
print(f"identify-push-listener-dialer-demo -d {addr}")
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
print("\nWaiting for incoming connections... (Ctrl+C to exit)")
# Keep running until interrupted
try:
await trio.sleep_forever()
except KeyboardInterrupt:
print("\n🛑 Shutting down listener...")
logger.info("Listener interrupted by user")
return
except Exception as e:
logger.error(f"Listener error: {e}")
raise
async def run_dialer(
@@ -293,9 +256,7 @@ async def run_dialer(
try:
await host.connect(peer_info)
logger.info("Successfully connected to listener!")
print("Successfully connected to listener!")
print(f" Connected to: {peer_info.peer_id}")
print(f" Full address: {destination}")
print("Successfully connected to listener!")
# Push identify information to the listener
logger.info("Pushing identify information to listener...")
@@ -309,7 +270,7 @@ async def run_dialer(
if success:
logger.info("Identify push completed successfully!")
print("Identify push completed successfully!")
print("Identify push completed successfully!")
logger.info("Example completed successfully!")
print("\nExample completed successfully!")
@@ -320,56 +281,16 @@ async def run_dialer(
logger.warning("Example completed with warnings.")
print("Example completed with warnings.")
except Exception as e:
error_msg = str(e)
logger.error(f"Error during identify push: {error_msg}")
print(f"\nError during identify push: {error_msg}")
# Check for specific format mismatch errors
if (
"Error parsing message" in error_msg
or "DecodeError" in error_msg
or "ParseFromString" in error_msg
):
print("\n" + "=" * 60)
print("FORMAT MISMATCH DETECTED!")
print("=" * 60)
if use_varint_format:
print(
"You are using length-prefixed format (default) but the "
"listener is using raw protobuf format."
)
print(
"\nTo fix this, run the dialer with the --raw-format flag:"
)
print(
f"identify-push-listener-dialer-demo --raw-format -d "
f"{destination}"
)
else:
print("You are using raw protobuf format but the listener")
print("is using length-prefixed format (default).")
print(
"\nTo fix this, run the dialer without the --raw-format "
"flag:"
)
print(f"identify-push-listener-dialer-demo -d {destination}")
print("=" * 60)
logger.error(f"Error during identify push: {str(e)}")
print(f"\nError during identify push: {str(e)}")
logger.error("Example completed with errors.")
print("Example completed with errors.")
# Continue execution despite the push error
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Cannot connect to peer: {peer_info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print("\n💡 Make sure the peer is running and the address is correct.")
return
else:
logger.error(f"Error during dialer operation: {error_msg}")
print(f"\nError during dialer operation: {error_msg}")
logger.error(f"Error during dialer operation: {str(e)}")
print(f"\nError during dialer operation: {str(e)}")
raise
@@ -380,21 +301,12 @@ def main() -> None:
Without arguments, it runs as a listener on random port.
With -d parameter, it runs as a dialer on random port.
Port 0 (default) means the OS will automatically assign an available port.
This prevents port conflicts when running multiple instances.
Use --raw-format to send raw protobuf messages (old format) instead of
length-prefixed protobuf messages (new format, default).
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p",
"--port",
default=0,
type=int,
help="source port number (0 = random available port)",
)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument(
"-d",
"--destination",
@@ -409,7 +321,6 @@ def main() -> None:
"length-prefixed (new format)"
),
)
args = parser.parse_args()
# Determine format: raw format if --raw-format is specified, otherwise
@@ -422,12 +333,12 @@ def main() -> None:
trio.run(run_dialer, args.port, args.destination, use_varint_format)
else:
# Run in listener mode with random available port if not specified
trio.run(run_listener, args.port, use_varint_format, args.raw_format)
trio.run(run_listener, args.port, use_varint_format)
except KeyboardInterrupt:
print("\n👋 Goodbye!")
logger.info("Application interrupted by user")
print("\nInterrupted by user")
logger.info("Interrupted by user")
except Exception as e:
print(f"\nError: {str(e)}")
print(f"\nError: {str(e)}")
logger.error("Error: %s", str(e))
sys.exit(1)
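The listener above reads the varint prefix one byte at a time and then exactly the announced number of message bytes. That loop as a standalone helper, assuming only a stream object with an async read(n) method plus the decode_varint_from_bytes utility imported above:

from libp2p.utils.varint import decode_varint_from_bytes

async def read_varint_prefixed(stream) -> bytes | None:
    # Accumulate prefix bytes until one arrives with the high bit clear.
    length_bytes = b""
    while True:
        b = await stream.read(1)
        if not b:
            return None  # stream ended before a complete prefix
        length_bytes += b
        if b[0] & 0x80 == 0:
            break
    msg_length = decode_varint_from_bytes(length_bytes)
    data = await stream.read(msg_length)
    # Reject short reads so callers never parse a truncated protobuf.
    return data if len(data) == msg_length else None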

View File

@@ -151,10 +151,7 @@ async def run_node(
host = new_host(key_pair=key_pair)
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
async with host.run(listen_addrs=[listen_addr]):
peer_id = host.get_id().pretty()
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
await connect_to_bootstrap_nodes(host, bootstrap_nodes)

View File

@@ -46,10 +46,7 @@ async def run(port: int) -> None:
logger.info("Starting peer Discovery")
host = new_host(key_pair=key_pair, enable_mDNS=True)
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
async with host.run(listen_addrs=[listen_addr]):
await trio.sleep_forever()

View File

@@ -59,9 +59,6 @@ async def run(port: int, destination: str) -> None:
host = new_host(listen_addrs=[listen_addr])
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
if not destination:
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)

View File

@@ -144,9 +144,6 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
pubsub = Pubsub(host, gossipsub)
termination_event = trio.Event() # Event to signal termination
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
logger.info(f"Node started with peer ID: {host.get_id()}")
logger.info(f"Listening on: {listen_addr}")
logger.info("Initializing PubSub and GossipSub...")

View File

@@ -251,7 +251,6 @@ def new_host(
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> IHost:
"""
@@ -265,7 +264,6 @@
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:return: return a host instance
"""
swarm = new_swarm(
@@ -278,7 +276,7 @@
)
if disc_opt is not None:
return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
return RoutedHost(swarm, disc_opt, enable_mDNS)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
__version__ = __version("libp2p")
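The net effect of this hunk on callers, sketched with an address from the deleted example:

from libp2p import new_host

# Pre-PR API (parameter removed above): the host wires the list into
# BootstrapDiscovery and dials out during startup.
host = new_host(
    bootstrap=[
        "/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
    ]
)

# Post-PR API: no bootstrap parameter; peers must be connected explicitly.
host = new_host()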

View File

@@ -16,7 +16,6 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Optional,
)
from multiaddr import (
@@ -42,19 +41,20 @@ from libp2p.io.abc import (
from libp2p.peer.id import (
ID,
)
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peerinfo import (
PeerInfo,
)
if TYPE_CHECKING:
from libp2p.peer.envelope import Envelope
from libp2p.peer.peer_record import PeerRecord
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.pubsub.pubsub import (
Pubsub,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.pubsub.pb import (
rpc_pb2,
)
@@ -493,71 +493,6 @@ class IAddrBook(ABC):
"""
# ------------------ certified-addr-book interface.py ---------------------
class ICertifiedAddrBook(ABC):
"""
Interface for a certified address book.
Provides methods for managing signed peer records
"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older than
the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated
addresses if accepted
Parameters
----------
envelope:
Signed envelope containing a PeerRecord.
ttl:
Time-to-live for the included multiaddrs (in seconds).
"""
@abstractmethod
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
Parameters
----------
peer_id : ID
The peer to look up.
"""
@abstractmethod
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no known
(non-expired) addresses.
This is a garbage collection mechanism: if all addresses for a peer have expired
or been cleared, there's no point holding onto its signed `Envelope`
Parameters
----------
peer_id : ID
The peer whose record we may delete.
"""
# -------------------------- keybook interface.py --------------------------
@@ -823,9 +758,7 @@ class IProtoBook(ABC):
# -------------------------- peerstore interface.py --------------------------
class IPeerStore(
IPeerMetadata, IAddrBook, ICertifiedAddrBook, IKeyBook, IMetrics, IProtoBook
):
class IPeerStore(IPeerMetadata, IAddrBook, IKeyBook, IMetrics, IProtoBook):
"""
Interface for a peer store.
@@ -960,65 +893,7 @@ class IPeerStore(
"""
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older
than the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated addresses if accepted
Parameters
----------
envelope:
Signed envelope containing a PeerRecord.
ttl:
Time-to-live for the included multiaddrs (in seconds).
"""
@abstractmethod
def get_peer_record(self, peer_id: ID) -> Optional["Envelope"]:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
Parameters
----------
peer_id : ID
The peer to look up.
"""
@abstractmethod
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
Delete the signed peer record for a peer if it has no
known (non-expired) addresses.
This is a garbage collection mechanism: if all addresses for a peer have expired
or been cleared, there's no point holding onto its signed `Envelope`
Parameters
----------
peer_id : ID
The peer whose record we may delete.
"""
# --------KEY-BOOK----------
@abstractmethod
def pubkey(self, peer_id: ID) -> PublicKey:
"""
@@ -1327,10 +1202,6 @@ class IPeerStore(
def clear_peerdata(self, peer_id: ID) -> None:
"""clear_peerdata"""
@abstractmethod
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
"""Start periodic cleanup of expired peer records and addresses."""
# -------------------------- listener interface.py --------------------------
@@ -1818,121 +1689,6 @@ class IHost(ABC):
"""
# -------------------------- peer-record interface.py --------------------------
class IPeerRecord(ABC):
"""
Interface for a libp2p PeerRecord object.
A PeerRecord contains metadata about a peer such as its ID, public addresses,
and a strictly increasing sequence number for versioning.
PeerRecords are used in signed routing Envelopes for secure peer data propagation.
"""
@abstractmethod
def domain(self) -> str:
"""
Return the domain string for this record type.
Used in envelope validation to distinguish different record types.
"""
@abstractmethod
def codec(self) -> bytes:
"""
Return a binary codec prefix that identifies the PeerRecord type.
This is prepended in signed envelopes to allow type-safe decoding.
"""
@abstractmethod
def to_protobuf(self) -> pb.PeerRecord:
"""
Convert this PeerRecord into its Protobuf representation.
:raises ValueError: if serialization fails (e.g., invalid peer ID).
:return: A populated protobuf `PeerRecord` message.
"""
@abstractmethod
def marshal_record(self) -> bytes:
"""
Serialize this PeerRecord into a byte string.
Used when signing or sealing the record in an envelope.
:raises ValueError: if protobuf serialization fails.
:return: Byte-encoded PeerRecord.
"""
@abstractmethod
def equal(self, other: object) -> bool:
"""
Compare this PeerRecord with another for equality.
Two PeerRecords are considered equal if:
- They have the same `peer_id`
- Their `seq` numbers match
- Their address lists are identical and ordered
:param other: Object to compare with.
:return: True if equal, False otherwise.
"""
# -------------------------- envelope interface.py --------------------------
class IEnvelope(ABC):
@abstractmethod
def marshal_envelope(self) -> bytes:
"""
Serialize this Envelope into its protobuf wire format.
Converts all envelope fields into a `pb.Envelope` protobuf message
and returns the serialized bytes.
:return: Serialized envelope as bytes.
"""
@abstractmethod
def validate(self, domain: str) -> None:
"""
Verify the envelope's signature within the given domain scope.
This ensures that the envelope has not been tampered with
and was signed under the correct usage context.
:param domain: Domain string that contextualizes the signature.
:raises ValueError: If the signature is invalid.
"""
@abstractmethod
def record(self) -> "PeerRecord":
"""
Lazily decode and return the embedded PeerRecord.
This method unmarshals the payload bytes into a `PeerRecord` instance,
using the registered codec to identify the type. The decoded result
is cached for future use.
:return: Decoded PeerRecord object.
:raises Exception: If decoding fails or payload type is unsupported.
"""
@abstractmethod
def equal(self, other: Any) -> bool:
"""
Compare this Envelope with another for structural equality.
Two envelopes are considered equal if:
- They have the same public key
- The payload type and payload bytes match
- Their signatures are identical
:param other: Another object to compare.
:return: True if equal, False otherwise.
"""
# -------------------------- peerdata interface.py --------------------------
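The seq-number rule in consume_peer_record above can be illustrated with a toy in-memory store. Everything below is hypothetical scaffolding; only the reject-older-records behavior comes from the docstrings being removed:

class ToyCertifiedAddrBook:
    """Toy illustration of the ICertifiedAddrBook seq rule; not the real class."""

    def __init__(self) -> None:
        self._records: dict = {}  # peer_id -> (seq, envelope)

    def consume_peer_record(self, envelope, ttl: int) -> bool:
        record = envelope.record()  # decode the embedded PeerRecord (per IEnvelope)
        stored = self._records.get(record.peer_id)
        if stored is not None and record.seq <= stored[0]:
            return False  # older or duplicate record: reject
        self._records[record.peer_id] = (record.seq, envelope)
        # A real implementation would also replace the peer's addrs using ttl.
        return True

    def get_peer_record(self, peer_id):
        stored = self._records.get(peer_id)
        return stored[1] if stored else None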

View File

@@ -1,5 +0,0 @@
"""Bootstrap peer discovery module for py-libp2p."""
from .bootstrap import BootstrapDiscovery
__all__ = ["BootstrapDiscovery"]

View File

@@ -1,94 +0,0 @@
import logging
from multiaddr import Multiaddr
from multiaddr.resolvers import DNSResolver
from libp2p.abc import ID, INetworkService, PeerInfo
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
from libp2p.discovery.events.peerDiscovery import peerDiscovery
from libp2p.peer.peerinfo import info_from_p2p_addr
logger = logging.getLogger("libp2p.discovery.bootstrap")
resolver = DNSResolver()
class BootstrapDiscovery:
"""
Bootstrap-based peer discovery for py-libp2p.
Connects to predefined bootstrap peers and adds them to peerstore.
"""
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
self.swarm = swarm
self.peerstore = swarm.peerstore
self.bootstrap_addrs = bootstrap_addrs or []
self.discovered_peers: set[str] = set()
async def start(self) -> None:
"""Process bootstrap addresses and emit peer discovery events."""
logger.debug(
f"Starting bootstrap discovery with "
f"{len(self.bootstrap_addrs)} bootstrap addresses"
)
# Validate and filter bootstrap addresses
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
for addr_str in self.bootstrap_addrs:
try:
await self._process_bootstrap_addr(addr_str)
except Exception as e:
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
def stop(self) -> None:
"""Clean up bootstrap discovery resources."""
logger.debug("Stopping bootstrap discovery")
self.discovered_peers.clear()
async def _process_bootstrap_addr(self, addr_str: str) -> None:
"""Convert string address to PeerInfo and add to peerstore."""
try:
multiaddr = Multiaddr(addr_str)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
self.add_addr(peer_info)
else:
self.add_addr(info_from_p2p_addr(multiaddr))
def is_dns_addr(self, addr: Multiaddr) -> bool:
"""Check if the address is a DNS address."""
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
def add_addr(self, peer_info: PeerInfo) -> None:
"""Add a peer to the peerstore and emit discovery event."""
# Skip if it's our own peer
if peer_info.peer_id == self.swarm.get_peer_id():
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
return
# Always add addresses to peerstore (allows multiple addresses for same peer)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
# Only emit discovery event if this is the first time we see this peer
peer_id_str = str(peer_info.peer_id)
if peer_id_str not in self.discovered_peers:
# Track discovered peer
self.discovered_peers.add(peer_id_str)
# Emit peer discovery event
peerDiscovery.emit_peer_discovered(peer_info)
logger.debug(f"Peer discovered: {peer_info.peer_id}")
else:
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")

View File

@@ -1,51 +0,0 @@
"""Utility functions for bootstrap discovery."""
import logging
from multiaddr import Multiaddr
from libp2p.peer.peerinfo import InvalidAddrError, PeerInfo, info_from_p2p_addr
logger = logging.getLogger("libp2p.discovery.bootstrap.utils")
def validate_bootstrap_addresses(addrs: list[str]) -> list[str]:
"""
Validate and filter bootstrap addresses.
:param addrs: List of bootstrap address strings
:return: List of valid bootstrap addresses
"""
valid_addrs = []
for addr_str in addrs:
try:
# Try to parse as multiaddr
multiaddr = Multiaddr(addr_str)
# Try to extract peer info (this validates the p2p component)
info_from_p2p_addr(multiaddr)
valid_addrs.append(addr_str)
logger.debug(f"Valid bootstrap address: {addr_str}")
except (InvalidAddrError, ValueError, Exception) as e:
logger.warning(f"Invalid bootstrap address '{addr_str}': {e}")
continue
return valid_addrs
def parse_bootstrap_peer_info(addr_str: str) -> PeerInfo | None:
"""
Parse bootstrap address string into PeerInfo.
:param addr_str: Bootstrap address string
:return: PeerInfo object or None if parsing fails
"""
try:
multiaddr = Multiaddr(addr_str)
return info_from_p2p_addr(multiaddr)
except Exception as e:
logger.error(f"Failed to parse bootstrap address '{addr_str}': {e}")
return None
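A quick usage sketch for these two helpers, using the same style of addresses as the deleted example:

from libp2p.discovery.bootstrap.utils import (
    parse_bootstrap_peer_info,
    validate_bootstrap_addresses,
)

addrs = [
    "/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
    "not-a-multiaddr",  # logged as invalid and dropped
]
valid = validate_bootstrap_addresses(addrs)  # keeps only the first entry
peer_info = parse_bootstrap_peer_info(valid[0])  # PeerInfo, or None on failure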

View File

@@ -29,7 +29,6 @@ from libp2p.custom_types import (
StreamHandlerFn,
TProtocol,
)
from libp2p.discovery.bootstrap.bootstrap import BootstrapDiscovery
from libp2p.discovery.mdns.mdns import MDNSDiscovery
from libp2p.host.defaults import (
get_default_protocols,
@@ -93,7 +92,6 @@ class BasicHost(IHost):
self,
network: INetworkService,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None:
@@ -107,8 +105,6 @@ class BasicHost(IHost):
self.multiselect_client = MultiselectClient()
if enable_mDNS:
self.mDNS = MDNSDiscovery(network)
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
def get_id(self) -> ID:
"""
@@ -176,16 +172,11 @@ class BasicHost(IHost):
if hasattr(self, "mDNS") and self.mDNS is not None:
logger.debug("Starting mDNS Discovery")
self.mDNS.start()
if hasattr(self, "bootstrap") and self.bootstrap is not None:
logger.debug("Starting Bootstrap Discovery")
await self.bootstrap.start()
try:
yield
finally:
if hasattr(self, "mDNS") and self.mDNS is not None:
self.mDNS.stop()
if hasattr(self, "bootstrap") and self.bootstrap is not None:
self.bootstrap.stop()
return _run()

View File

@@ -19,13 +19,9 @@ class RoutedHost(BasicHost):
_router: IPeerRouting
def __init__(
self,
network: INetworkService,
router: IPeerRouting,
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
self, network: INetworkService, router: IPeerRouting, enable_mDNS: bool = False
):
super().__init__(network, enable_mDNS, bootstrap)
super().__init__(network, enable_mDNS)
self._router = router
async def connect(self, peer_info: PeerInfo) -> None:

View File

@@ -15,8 +15,6 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@@ -65,11 +63,6 @@ def _mk_identify_protobuf(
laddrs = host.get_addrs()
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
protocol_version=PROTOCOL_VERSION,
@@ -78,7 +71,6 @@
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
)
@@ -121,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:
def identify_handler_for(
host: IHost, use_varint_format: bool = True
host: IHost, use_varint_format: bool = False
) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
# get observed address from ``stream``

View File

@@ -9,5 +9,4 @@ message Identify {
repeated bytes listen_addrs = 2;
optional bytes observed_addr = 4;
repeated string protocols = 3;
optional bytes signedPeerRecord = 8;
}
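Field 8 carried the envelope built in _mk_identify_protobuf (removed in the identify.py hunk above). That construction as a standalone sketch against the pre-PR API:

from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord

def make_signed_peer_record(host) -> bytes:
    # Build a PeerRecord for ourselves and seal it with our private key,
    # yielding the bytes that used to populate signedPeerRecord (field 8).
    record = PeerRecord(host.get_id(), host.get_addrs())
    envelope = seal_record(record, host.get_private_key())
    return envelope.marshal_envelope()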

View File

@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/identity/identify/pb/identify.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -14,13 +13,13 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\xa9\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t\x12\x18\n\x10signedPeerRecord\x18\x08 \x01(\x0c')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*libp2p/identity/identify/pb/identify.proto\x12\x0bidentify.pb\"\x8f\x01\n\x08Identify\x12\x18\n\x10protocol_version\x18\x05 \x01(\t\x12\x15\n\ragent_version\x18\x06 \x01(\t\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x14\n\x0clisten_addrs\x18\x02 \x03(\x0c\x12\x15\n\robserved_addr\x18\x04 \x01(\x0c\x12\x11\n\tprotocols\x18\x03 \x03(\t')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', _globals)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.identity.identify.pb.identify_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_IDENTIFY']._serialized_start=60
_globals['_IDENTIFY']._serialized_end=229
_IDENTIFY._serialized_start=60
_IDENTIFY._serialized_end=203
# @@protoc_insertion_point(module_scope)

View File

@@ -1,24 +1,46 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
DESCRIPTOR: _descriptor.FileDescriptor
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
class Identify(_message.Message):
__slots__ = ("protocol_version", "agent_version", "public_key", "listen_addrs", "observed_addr", "protocols", "signedPeerRecord")
PROTOCOL_VERSION_FIELD_NUMBER: _ClassVar[int]
AGENT_VERSION_FIELD_NUMBER: _ClassVar[int]
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
LISTEN_ADDRS_FIELD_NUMBER: _ClassVar[int]
OBSERVED_ADDR_FIELD_NUMBER: _ClassVar[int]
PROTOCOLS_FIELD_NUMBER: _ClassVar[int]
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
protocol_version: str
agent_version: str
public_key: bytes
listen_addrs: _containers.RepeatedScalarFieldContainer[bytes]
observed_addr: bytes
protocols: _containers.RepeatedScalarFieldContainer[str]
signedPeerRecord: bytes
def __init__(self, protocol_version: _Optional[str] = ..., agent_version: _Optional[str] = ..., public_key: _Optional[bytes] = ..., listen_addrs: _Optional[_Iterable[bytes]] = ..., observed_addr: _Optional[bytes] = ..., protocols: _Optional[_Iterable[str]] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Identify(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
PROTOCOL_VERSION_FIELD_NUMBER: builtins.int
AGENT_VERSION_FIELD_NUMBER: builtins.int
PUBLIC_KEY_FIELD_NUMBER: builtins.int
LISTEN_ADDRS_FIELD_NUMBER: builtins.int
OBSERVED_ADDR_FIELD_NUMBER: builtins.int
PROTOCOLS_FIELD_NUMBER: builtins.int
protocol_version: builtins.str
agent_version: builtins.str
public_key: builtins.bytes
observed_addr: builtins.bytes
@property
def listen_addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
@property
def protocols(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
protocol_version: builtins.str | None = ...,
agent_version: builtins.str | None = ...,
public_key: builtins.bytes | None = ...,
listen_addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
observed_addr: builtins.bytes | None = ...,
protocols: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["agent_version", b"agent_version", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "public_key", b"public_key"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agent_version", b"agent_version", "listen_addrs", b"listen_addrs", "observed_addr", b"observed_addr", "protocol_version", b"protocol_version", "protocols", b"protocols", "public_key", b"public_key"]) -> None: ...
global___Identify = Identify
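These stubs type-check ordinary protobuf usage of the message; for example (field values are placeholders):

from libp2p.identity.identify.pb.identify_pb2 import Identify

msg = Identify(agent_version="py-libp2p/example", protocols=["/ipfs/id/1.0.0"])
assert msg.HasField("agent_version")  # presence is tracked for optional scalars
data = msg.SerializeToString()  # wire-format bytes
round_tripped = Identify()
round_tripped.ParseFromString(data)
assert list(round_tripped.protocols) == ["/ipfs/id/1.0.0"]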

View File

@@ -20,7 +20,6 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
@@ -29,7 +28,7 @@ from libp2p.utils import (
varint,
)
from libp2p.utils.varint import (
read_length_prefixed_protobuf,
decode_varint_from_bytes,
)
from ..identify.identify import (
@@ -67,8 +66,49 @@ def identify_push_handler_for(
peer_id = stream.muxed_conn.peer_id
try:
# Use the utility function to read the protobuf message
data = await read_length_prefixed_protobuf(stream, use_varint_format)
if use_varint_format:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
# For raw format, we need to read all data before the stream is closed
data = b""
try:
# Read all available data in a single operation
data = await stream.read()
except StreamClosed:
# Try to read any remaining data
try:
data = await stream.read()
except Exception:
pass
# If we got no data, log a warning and return
if not data:
logger.warning(
"No data received in raw format from peer %s", peer_id
)
return
identify_msg = Identify()
identify_msg.ParseFromString(data)
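For reference, the length prefix read byte-by-byte above is an unsigned varint: seven data bits per byte, with the high bit set on every byte except the last. A minimal self-contained sketch of that framing (the helper names here are illustrative, not the library's own):

def encode_uvarint(value: int) -> bytes:
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # continuation bit: more bytes follow
        else:
            out.append(byte)
            return bytes(out)

def decode_uvarint(data: bytes) -> tuple[int, int]:
    """Return (value, number of bytes consumed)."""
    value = shift = 0
    for i, byte in enumerate(data):
        value |= (byte & 0x7F) << shift
        if byte & 0x80 == 0:
            return value, i + 1
        shift += 7
    raise ValueError("truncated varint")

payload = b"identify-protobuf-bytes"
framed = encode_uvarint(len(payload)) + payload
length, consumed = decode_uvarint(framed)
assert framed[consumed:consumed + length] == payload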
@ -79,11 +119,6 @@ def identify_push_handler_for(
)
logger.debug("Successfully processed identify/push from peer %s", peer_id)
# Send acknowledgment to indicate successful processing
# This ensures the sender knows the message was received before closing
await stream.write(b"OK")
except StreamClosed:
logger.debug(
"Stream closed while processing identify/push from %s", peer_id
@ -92,10 +127,7 @@ def identify_push_handler_for(
logger.error("Error processing identify/push from %s: %s", peer_id, e)
finally:
# Close the stream after processing
try:
await stream.close()
except Exception:
pass # Ignore errors when closing
return handle_identify_push
@ -141,19 +173,6 @@ async def _update_peerstore_from_identify(
except Exception as e:
logger.error("Error updating protocols for peer %s: %s", peer_id, e)
if identify_msg.HasField("signedPeerRecord"):
try:
            # Convert the signed peer record (Envelope) from protobuf bytes
envelope, _ = consume_envelope(
identify_msg.signedPeerRecord, "libp2p-peer-record"
)
# Use a default TTL of 2 hours (7200 seconds)
if not peerstore.consume_peer_record(envelope, 7200):
logger.error("Updating Certified-Addr-Book was unsuccessful")
except Exception as e:
logger.error(
"Error updating the certified addr book for peer %s: %s", peer_id, e
)
# Update observed address if present
if identify_msg.HasField("observed_addr") and identify_msg.observed_addr:
try:
@ -207,20 +226,7 @@ async def push_identify_to_peer(
# Send raw protobuf message
await stream.write(response)
# Wait for acknowledgment from the receiver with timeout
# This ensures the message was processed before closing
try:
with trio.move_on_after(1.0): # 1 second timeout
ack = await stream.read(2) # Read "OK" acknowledgment
if ack != b"OK":
logger.warning(
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
)
except Exception as e:
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
# Continue anyway, as the message might have been processed
# Close the stream after acknowledgment (or timeout)
# Close the stream
await stream.close()
logger.debug("Successfully pushed identify to peer %s", peer_id)

View File

@ -1,271 +0,0 @@
from typing import Any, cast
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
from libp2p.crypto.secp256k1 import Secp256k1PublicKey
import libp2p.peer.pb.crypto_pb2 as cryto_pb
import libp2p.peer.pb.envelope_pb2 as pb
import libp2p.peer.pb.peer_record_pb2 as record_pb
from libp2p.peer.peer_record import (
PeerRecord,
peer_record_from_protobuf,
unmarshal_record,
)
from libp2p.utils.varint import encode_uvarint
ENVELOPE_DOMAIN = "libp2p-peer-record"
PEER_RECORD_CODEC = b"\x03\x01"
class Envelope:
"""
A signed wrapper around a serialized libp2p record.
Envelopes are cryptographically signed by the author's private key
and are scoped to a specific 'domain' to prevent cross-protocol replay.
Attributes:
public_key: The public key that can verify the envelope's signature.
payload_type: A multicodec code identifying the type of payload inside.
raw_payload: The raw serialized record data.
signature: Signature over the domain-scoped payload content.
"""
public_key: PublicKey
payload_type: bytes
raw_payload: bytes
signature: bytes
_cached_record: PeerRecord | None = None
_unmarshal_error: Exception | None = None
def __init__(
self,
public_key: PublicKey,
payload_type: bytes,
raw_payload: bytes,
signature: bytes,
):
self.public_key = public_key
self.payload_type = payload_type
self.raw_payload = raw_payload
self.signature = signature
def marshal_envelope(self) -> bytes:
"""
Serialize this Envelope into its protobuf wire format.
Converts all envelope fields into a `pb.Envelope` protobuf message
and returns the serialized bytes.
:return: Serialized envelope as bytes.
"""
pb_env = pb.Envelope(
public_key=pub_key_to_protobuf(self.public_key),
payload_type=self.payload_type,
payload=self.raw_payload,
signature=self.signature,
)
return pb_env.SerializeToString()
def validate(self, domain: str) -> None:
"""
Verify the envelope's signature within the given domain scope.
This ensures that the envelope has not been tampered with
and was signed under the correct usage context.
:param domain: Domain string that contextualizes the signature.
:raises ValueError: If the signature is invalid.
"""
unsigned = make_unsigned(domain, self.payload_type, self.raw_payload)
if not self.public_key.verify(unsigned, self.signature):
raise ValueError("Invalid envelope signature")
def record(self) -> PeerRecord:
"""
Lazily decode and return the embedded PeerRecord.
This method unmarshals the payload bytes into a `PeerRecord` instance,
using the registered codec to identify the type. The decoded result
is cached for future use.
:return: Decoded PeerRecord object.
:raises Exception: If decoding fails or payload type is unsupported.
"""
if self._cached_record is not None:
return self._cached_record
try:
if self.payload_type != PEER_RECORD_CODEC:
raise ValueError("Unsuported payload type in envelope")
msg = record_pb.PeerRecord()
msg.ParseFromString(self.raw_payload)
self._cached_record = peer_record_from_protobuf(msg)
return self._cached_record
except Exception as e:
self._unmarshal_error = e
raise
def equal(self, other: Any) -> bool:
"""
Compare this Envelope with another for structural equality.
Two envelopes are considered equal if:
- They have the same public key
- The payload type and payload bytes match
- Their signatures are identical
:param other: Another object to compare.
:return: True if equal, False otherwise.
"""
if isinstance(other, Envelope):
return (
                self.public_key == other.public_key
and self.payload_type == other.payload_type
and self.signature == other.signature
and self.raw_payload == other.raw_payload
)
return False
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""
Convert a Python PublicKey object to its protobuf equivalent.
:param pub_key: A libp2p-compatible PublicKey instance.
:return: Serialized protobuf PublicKey message.
"""
internal_key_type = pub_key.get_type()
key_type = cast(cryto_pb.KeyType, internal_key_type.value)
data = pub_key.to_bytes()
protobuf_key = cryto_pb.PublicKey(Type=key_type, Data=data)
return protobuf_key
def pub_key_from_protobuf(pb_key: cryto_pb.PublicKey) -> PublicKey:
"""
Parse a protobuf PublicKey message into a native libp2p PublicKey.
Supports Ed25519, RSA, and Secp256k1 key types.
:param pb_key: Protobuf representation of a public key.
:return: Parsed PublicKey object.
:raises ValueError: If the key type is unrecognized.
"""
if pb_key.Type == cryto_pb.KeyType.Ed25519:
return Ed25519PublicKey.from_bytes(pb_key.Data)
elif pb_key.Type == cryto_pb.KeyType.RSA:
return RSAPublicKey.from_bytes(pb_key.Data)
elif pb_key.Type == cryto_pb.KeyType.Secp256k1:
return Secp256k1PublicKey.from_bytes(pb_key.Data)
# libp2p.crypto.ecdsa not implemented
else:
raise ValueError(f"Unknown key type: {pb_key.Type}")
def seal_record(record: PeerRecord, private_key: PrivateKey) -> Envelope:
"""
Create and sign a new Envelope from a PeerRecord.
The record is serialized and signed in the scope of its domain and codec.
The result is a self-contained, verifiable Envelope.
:param record: A PeerRecord to encapsulate.
:param private_key: The signer's private key.
:return: A signed Envelope instance.
"""
payload = record.marshal_record()
unsigned = make_unsigned(record.domain(), record.codec(), payload)
signature = private_key.sign(unsigned)
return Envelope(
public_key=private_key.get_public_key(),
payload_type=record.codec(),
raw_payload=payload,
signature=signature,
)
def consume_envelope(data: bytes, domain: str) -> tuple[Envelope, PeerRecord]:
"""
Parse, validate, and decode an Envelope from bytes.
Validates the envelope's signature using the given domain and decodes
the inner payload into a PeerRecord.
:param data: Serialized envelope bytes.
:param domain: Domain string to verify signature against.
:return: Tuple of (Envelope, PeerRecord).
:raises ValueError: If signature validation or decoding fails.
"""
env = unmarshal_envelope(data)
env.validate(domain)
record = env.record()
return env, record
def unmarshal_envelope(data: bytes) -> Envelope:
"""
Deserialize an Envelope from its wire format.
This parses the protobuf fields without verifying the signature.
:param data: Serialized envelope bytes.
:return: Parsed Envelope object.
:raises DecodeError: If protobuf parsing fails.
"""
pb_env = pb.Envelope()
pb_env.ParseFromString(data)
pk = pub_key_from_protobuf(pb_env.public_key)
return Envelope(
public_key=pk,
payload_type=pb_env.payload_type,
raw_payload=pb_env.payload,
signature=pb_env.signature,
)
def make_unsigned(domain: str, payload_type: bytes, payload: bytes) -> bytes:
"""
Build a byte buffer to be signed for an Envelope.
The unsigned byte structure is:
varint(len(domain)) || domain ||
varint(len(payload_type)) || payload_type ||
varint(len(payload)) || payload
This is the exact input used during signing and verification.
:param domain: Domain string for signature scoping.
:param payload_type: Identifier for the type of payload.
:param payload: Raw serialized payload bytes.
:return: Byte buffer to be signed or verified.
"""
fields = [domain.encode(), payload_type, payload]
buf = bytearray()
for field in fields:
buf.extend(encode_uvarint(len(field)))
buf.extend(field)
return bytes(buf)
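As a concrete check of this layout, the three field lengths (18, 2, and 7) each fit in a single varint byte:

# Worked example of the unsigned buffer for a toy payload:
unsigned = make_unsigned("libp2p-peer-record", b"\x03\x01", b"payload")
assert unsigned == b"\x12libp2p-peer-record" + b"\x02\x03\x01" + b"\x07payload"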
def debug_dump_envelope(env: Envelope) -> None:
print("\n=== Envelope ===")
print(f"Payload Type: {env.payload_type!r}")
print(f"Signature: {env.signature.hex()} ({len(env.signature)} bytes)")
print(f"Raw Payload: {env.raw_payload.hex()} ({len(env.raw_payload)} bytes)")
try:
peer_record = unmarshal_record(env.raw_payload)
print("\n=== Parsed PeerRecord ===")
print(peer_record)
except Exception as e:
print("Failed to parse PeerRecord:", e)

View File

@ -1,4 +1,3 @@
import functools
import hashlib
import base58
@ -37,23 +36,25 @@ if ENABLE_INLINING:
class ID:
_bytes: bytes
_xor_id: int | None = None
_b58_str: str | None = None
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@functools.cached_property
@property
def xor_id(self) -> int:
return int(sha256_digest(self._bytes).hex(), 16)
@functools.cached_property
def base58(self) -> str:
return base58.b58encode(self._bytes).decode()
if not self._xor_id:
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
return self._xor_id
def to_bytes(self) -> bytes:
return self._bytes
def to_base58(self) -> str:
return self.base58
if not self._b58_str:
self._b58_str = base58.b58encode(self._bytes).decode()
return self._b58_str
def __repr__(self) -> str:
return f"<libp2p.peer.id.ID ({self!s})>"

View File

@ -1,22 +0,0 @@
syntax = "proto3";
package libp2p.peer.pb.crypto;
option go_package = "github.com/libp2p/go-libp2p/core/crypto/pb";
enum KeyType {
RSA = 0;
Ed25519 = 1;
Secp256k1 = 2;
ECDSA = 3;
}
message PublicKey {
KeyType Type = 1;
bytes Data = 2;
}
message PrivateKey {
KeyType Type = 1;
bytes Data = 2;
}

View File

@ -1,31 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/crypto.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1blibp2p/peer/pb/crypto.proto\x12\x15libp2p.peer.pb.crypto\"G\n\tPublicKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c\"H\n\nPrivateKey\x12,\n\x04Type\x18\x01 \x01(\x0e\x32\x1e.libp2p.peer.pb.crypto.KeyType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c*9\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03\x42,Z*github.com/libp2p/go-libp2p/core/crypto/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.crypto_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/crypto/pb'
_globals['_KEYTYPE']._serialized_start=201
_globals['_KEYTYPE']._serialized_end=258
_globals['_PUBLICKEY']._serialized_start=54
_globals['_PUBLICKEY']._serialized_end=125
_globals['_PRIVATEKEY']._serialized_start=127
_globals['_PRIVATEKEY']._serialized_end=199
# @@protoc_insertion_point(module_scope)

View File

@ -1,33 +0,0 @@
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class KeyType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
RSA: _ClassVar[KeyType]
Ed25519: _ClassVar[KeyType]
Secp256k1: _ClassVar[KeyType]
ECDSA: _ClassVar[KeyType]
RSA: KeyType
Ed25519: KeyType
Secp256k1: KeyType
ECDSA: KeyType
class PublicKey(_message.Message):
__slots__ = ("Type", "Data")
TYPE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
Type: KeyType
Data: bytes
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...
class PrivateKey(_message.Message):
__slots__ = ("Type", "Data")
TYPE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
Type: KeyType
Data: bytes
def __init__(self, Type: _Optional[_Union[KeyType, str]] = ..., Data: _Optional[bytes] = ...) -> None: ...

View File

@ -1,14 +0,0 @@
syntax = "proto3";
package libp2p.peer.pb.record;
import "libp2p/peer/pb/crypto.proto";
option go_package = "github.com/libp2p/go-libp2p/core/record/pb";
message Envelope {
libp2p.peer.pb.crypto.PublicKey public_key = 1;
bytes payload_type = 2;
bytes payload = 3;
bytes signature = 5;
}

View File

@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/envelope.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from libp2p.peer.pb import crypto_pb2 as libp2p_dot_peer_dot_pb_dot_crypto__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dlibp2p/peer/pb/envelope.proto\x12\x15libp2p.peer.pb.record\x1a\x1blibp2p/peer/pb/crypto.proto\"z\n\x08\x45nvelope\x12\x34\n\npublic_key\x18\x01 \x01(\x0b\x32 .libp2p.peer.pb.crypto.PublicKey\x12\x14\n\x0cpayload_type\x18\x02 \x01(\x0c\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x42,Z*github.com/libp2p/go-libp2p/core/record/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.envelope_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z*github.com/libp2p/go-libp2p/core/record/pb'
_globals['_ENVELOPE']._serialized_start=85
_globals['_ENVELOPE']._serialized_end=207
# @@protoc_insertion_point(module_scope)

View File

@ -1,18 +0,0 @@
from libp2p.peer.pb import crypto_pb2 as _crypto_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class Envelope(_message.Message):
__slots__ = ("public_key", "payload_type", "payload", "signature")
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
PAYLOAD_TYPE_FIELD_NUMBER: _ClassVar[int]
PAYLOAD_FIELD_NUMBER: _ClassVar[int]
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
public_key: _crypto_pb2.PublicKey
payload_type: bytes
payload: bytes
signature: bytes
def __init__(self, public_key: _Optional[_Union[_crypto_pb2.PublicKey, _Mapping]] = ..., payload_type: _Optional[bytes] = ..., payload: _Optional[bytes] = ..., signature: _Optional[bytes] = ...) -> None: ... # type: ignore[type-arg]

View File

@ -1,31 +0,0 @@
syntax = "proto3";
package peer.pb;
option go_package = "github.com/libp2p/go-libp2p/core/peer/pb";
// PeerRecord messages contain information that is useful to share with other peers.
// Currently, a PeerRecord contains the public listen addresses for a peer, but this
// is expected to expand to include other information in the future.
//
// PeerRecords are designed to be serialized to bytes and placed inside of
// SignedEnvelopes before sharing with other peers.
// See https://github.com/libp2p/go-libp2p/blob/master/core/record/pb/envelope.proto for
// the SignedEnvelope definition.
message PeerRecord {
// AddressInfo is a wrapper around a binary multiaddr. It is defined as a
// separate message to allow us to add per-address metadata in the future.
message AddressInfo {
bytes multiaddr = 1;
}
// peer_id contains a libp2p peer id in its binary representation.
bytes peer_id = 1;
// seq contains a monotonically-increasing sequence counter to order PeerRecords in time.
uint64 seq = 2;
// addresses is a list of public listen addresses for the peer.
repeated AddressInfo addresses = 3;
}

View File

@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/peer/pb/peer_record.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/peer/pb/peer_record.proto\x12\x07peer.pb\"\x80\x01\n\nPeerRecord\x12\x0f\n\x07peer_id\x18\x01 \x01(\x0c\x12\x0b\n\x03seq\x18\x02 \x01(\x04\x12\x32\n\taddresses\x18\x03 \x03(\x0b\x32\x1f.peer.pb.PeerRecord.AddressInfo\x1a \n\x0b\x41\x64\x64ressInfo\x12\x11\n\tmultiaddr\x18\x01 \x01(\x0c\x42*Z(github.com/libp2p/go-libp2p/core/peer/pbb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.peer.pb.peer_record_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'Z(github.com/libp2p/go-libp2p/core/peer/pb'
_globals['_PEERRECORD']._serialized_start=46
_globals['_PEERRECORD']._serialized_end=174
_globals['_PEERRECORD_ADDRESSINFO']._serialized_start=142
_globals['_PEERRECORD_ADDRESSINFO']._serialized_end=174
# @@protoc_insertion_point(module_scope)

View File

@ -1,21 +0,0 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class PeerRecord(_message.Message):
__slots__ = ("peer_id", "seq", "addresses")
class AddressInfo(_message.Message):
__slots__ = ("multiaddr",)
MULTIADDR_FIELD_NUMBER: _ClassVar[int]
multiaddr: bytes
def __init__(self, multiaddr: _Optional[bytes] = ...) -> None: ...
PEER_ID_FIELD_NUMBER: _ClassVar[int]
SEQ_FIELD_NUMBER: _ClassVar[int]
ADDRESSES_FIELD_NUMBER: _ClassVar[int]
peer_id: bytes
seq: int
addresses: _containers.RepeatedCompositeFieldContainer[PeerRecord.AddressInfo]
def __init__(self, peer_id: _Optional[bytes] = ..., seq: _Optional[int] = ..., addresses: _Optional[_Iterable[_Union[PeerRecord.AddressInfo, _Mapping]]] = ...) -> None: ... # type: ignore[type-arg]

View File

@ -1,251 +0,0 @@
from collections.abc import Sequence
import threading
import time
from typing import Any
from multiaddr import Multiaddr
from libp2p.abc import IPeerRecord
from libp2p.peer.id import ID
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peerinfo import PeerInfo
PEER_RECORD_ENVELOPE_DOMAIN = "libp2p-peer-record"
PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = b"\x03\x01"
_last_timestamp_lock = threading.Lock()
_last_timestamp: int = 0
class PeerRecord(IPeerRecord):
"""
    A record that contains metadata about a peer in the libp2p network.
    This includes:
    - `peer_id`: The peer's globally unique identifier.
- `addrs`: A list of the peer's publicly reachable multiaddrs.
- `seq`: A strictly monotonically increasing timestamp used
to order records over time.
PeerRecords are designed to be signed and transmitted in libp2p routing Envelopes.
"""
peer_id: ID
addrs: list[Multiaddr]
seq: int
def __init__(
self,
peer_id: ID | None = None,
addrs: list[Multiaddr] | None = None,
seq: int | None = None,
) -> None:
"""
Initialize a new PeerRecord.
If `seq` is not provided, a timestamp-based strictly increasing sequence
number will be generated.
:param peer_id: ID of the peer this record refers to.
:param addrs: Public multiaddrs of the peer.
:param seq: Monotonic sequence number.
"""
if peer_id is not None:
self.peer_id = peer_id
self.addrs = addrs or []
if seq is not None:
self.seq = seq
else:
self.seq = timestamp_seq()
def __repr__(self) -> str:
return (
f"PeerRecord(\n"
f" peer_id={self.peer_id},\n"
f" multiaddrs={[str(m) for m in self.addrs]},\n"
f" seq={self.seq}\n"
f")"
)
def domain(self) -> str:
"""
Return the domain string associated with this PeerRecord.
Used during record signing and envelope validation to identify the record type.
"""
return PEER_RECORD_ENVELOPE_DOMAIN
def codec(self) -> bytes:
"""
Return the codec identifier for PeerRecords.
        This binary prefix helps distinguish PeerRecords in serialized envelopes.
"""
return PEER_RECORD_ENVELOPE_PAYLOAD_TYPE
def to_protobuf(self) -> pb.PeerRecord:
"""
Convert the current PeerRecord into a ProtoBuf PeerRecord message.
:raises ValueError: if peer_id serialization fails.
:return: A ProtoBuf-encoded PeerRecord message object.
"""
try:
id_bytes = self.peer_id.to_bytes()
except Exception as e:
raise ValueError(f"failed to marshal peer_id: {e}")
msg = pb.PeerRecord()
msg.peer_id = id_bytes
msg.seq = self.seq
msg.addresses.extend(addrs_to_protobuf(self.addrs))
return msg
def marshal_record(self) -> bytes:
"""
Serialize a PeerRecord into raw bytes suitable for embedding in an Envelope.
This is typically called during the process of signing or sealing the record.
:raises ValueError: if serialization to protobuf fails.
:return: Serialized PeerRecord bytes.
"""
try:
msg = self.to_protobuf()
return msg.SerializeToString()
except Exception as e:
raise ValueError(f"failed to marshal PeerRecord: {e}")
def equal(self, other: Any) -> bool:
"""
Check if this PeerRecord is identical to another.
Two PeerRecords are considered equal if:
- Their peer IDs match.
- Their sequence numbers are identical.
- Their address lists are identical and in the same order.
:param other: Another PeerRecord instance.
        :return: True if all fields match, False otherwise.
"""
        if not isinstance(other, PeerRecord):
            return False
        return (
            self.peer_id == other.peer_id
            and self.seq == other.seq
            and self.addrs == other.addrs
        )
def unmarshal_record(data: bytes) -> PeerRecord:
"""
Deserialize a PeerRecord from its serialized byte representation.
    Typically used when receiving a PeerRecord inside a signed routing Envelope.
    :param data: Serialized protobuf-encoded bytes.
    :raises ValueError: if parsing or conversion fails.
    :return: A valid PeerRecord instance.
"""
if data is None:
raise ValueError("cannot unmarshal PeerRecord from None")
msg = pb.PeerRecord()
try:
msg.ParseFromString(data)
except Exception as e:
raise ValueError(f"Failed to parse PeerRecord protobuf: {e}")
try:
record = peer_record_from_protobuf(msg)
except Exception as e:
raise ValueError(f"Failed to convert protobuf to PeerRecord: {e}")
return record
def timestamp_seq() -> int:
"""
Generate a strictly increasing timestamp-based sequence number.
    A lock tracks the last issued value, so even if multiple PeerRecords are
    generated in the same nanosecond their `seq` values remain strictly
    increasing.
:return: A strictly increasing integer timestamp.
"""
global _last_timestamp
now = int(time.time_ns())
with _last_timestamp_lock:
if now <= _last_timestamp:
now = _last_timestamp + 1
_last_timestamp = now
return now
def peer_record_from_peer_info(info: PeerInfo) -> PeerRecord:
"""
Create a PeerRecord from a PeerInfo object.
This automatically assigns a timestamp-based sequence number to the record.
:param info: A PeerInfo instance (contains peer_id and addrs).
:return: A PeerRecord instance.
"""
record = PeerRecord()
record.peer_id = info.peer_id
record.addrs = info.addrs
return record
def peer_record_from_protobuf(msg: pb.PeerRecord) -> PeerRecord:
"""
Convert a protobuf PeerRecord message into a PeerRecord object.
:param msg: Protobuf PeerRecord message.
:raises ValueError: if the peer_id cannot be parsed.
:return: A deserialized PeerRecord instance.
"""
try:
peer_id = ID(msg.peer_id)
except Exception as e:
raise ValueError(f"Failed to unmarshal peer_id: {e}")
addrs = addrs_from_protobuf(msg.addresses)
seq = msg.seq
return PeerRecord(peer_id, addrs, seq)
def addrs_from_protobuf(addrs: Sequence[pb.PeerRecord.AddressInfo]) -> list[Multiaddr]:
"""
Convert a list of protobuf address records to Multiaddr objects.
:param addrs: A list of protobuf PeerRecord.AddressInfo messages.
:return: A list of decoded Multiaddr instances (invalid ones are skipped).
"""
out = []
for addr_info in addrs:
try:
addr = Multiaddr(addr_info.multiaddr)
out.append(addr)
except Exception:
continue
return out
def addrs_to_protobuf(addrs: list[Multiaddr]) -> list[pb.PeerRecord.AddressInfo]:
"""
Convert a list of Multiaddr objects into their protobuf representation.
:param addrs: A list of Multiaddr instances.
:return: A list of PeerRecord.AddressInfo protobuf messages.
"""
out = []
for addr in addrs:
addr_info = pb.PeerRecord.AddressInfo()
addr_info.multiaddr = addr.to_bytes()
out.append(addr_info)
return out
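A short roundtrip sketch for the marshal/unmarshal pair defined above (the peer-id bytes are an illustrative identity multihash, not a real peer):

from multiaddr import Multiaddr

rec = PeerRecord(
    peer_id=ID(b"\x00\x10" + b"\x11" * 16),  # hypothetical identity multihash
    addrs=[Multiaddr("/ip4/10.0.0.1/tcp/4001")],
)
wire = rec.marshal_record()
rec2 = unmarshal_record(wire)
assert rec2.equal(rec) and rec2.seq == rec.seq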

View File

@ -23,7 +23,6 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from .id import (
ID,
@ -39,25 +38,12 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
# TODO: Set up an async task for periodic peer-store cleanup
# for expired addresses and records.
class PeerRecordState:
envelope: Envelope
seq: int
def __init__(self, envelope: Envelope, seq: int):
self.envelope = envelope
self.seq = seq
class PeerStore(IPeerStore):
peer_data_map: dict[ID, PeerData]
def __init__(self, max_records: int = 10000) -> None:
def __init__(self) -> None:
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.max_records = max_records
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
@ -84,10 +70,6 @@ class PeerStore(IPeerStore):
else:
raise PeerStoreError("peer ID not found")
# Clear the peer records
if peer_id in self.peer_record_map:
self.peer_record_map.pop(peer_id, None)
def valid_peer_ids(self) -> list[ID]:
"""
:return: all of the valid peer IDs stored in peer store
@ -100,38 +82,6 @@ class PeerStore(IPeerStore):
peer_data.clear_addrs()
return valid_peer_ids
def _enforce_record_limit(self) -> None:
"""Enforce maximum number of stored records."""
if len(self.peer_record_map) > self.max_records:
            # Remove the oldest records based on sequence number
sorted_records = sorted(
self.peer_record_map.items(), key=lambda x: x[1].seq
)
records_to_remove = len(self.peer_record_map) - self.max_records
for peer_id, _ in sorted_records[:records_to_remove]:
                self.maybe_delete_peer_record(peer_id)
                # maybe_delete_peer_record may already have removed the entry,
                # so pop with a default instead of `del` to avoid a KeyError
                self.peer_record_map.pop(peer_id, None)
async def start_cleanup_task(self, cleanup_interval: int = 3600) -> None:
"""Start periodic cleanup of expired peer records and addresses."""
while True:
await trio.sleep(cleanup_interval)
self._cleanup_expired_records()
def _cleanup_expired_records(self) -> None:
"""Remove expired peer records and addresses"""
expired_peers = []
for peer_id, peer_data in self.peer_data_map.items():
if peer_data.is_expired():
expired_peers.append(peer_id)
for peer_id in expired_peers:
self.maybe_delete_peer_record(peer_id)
del self.peer_data_map[peer_id]
self._enforce_record_limit()
# --------PROTO-BOOK--------
def get_protocols(self, peer_id: ID) -> list[str]:
@ -215,85 +165,6 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id]
peer_data.clear_metadata()
# -----CERT-ADDR-BOOK-----
# TODO: Make proper use of this function
def maybe_delete_peer_record(self, peer_id: ID) -> None:
"""
        Delete the signed peer record for a peer if it has no known
        (non-expired) addresses.
        This is a garbage collection mechanism: if all addresses for a peer have expired
        or been cleared, there's no point holding onto its signed `Envelope`.
        :param peer_id: The peer whose record we may delete.
"""
if peer_id in self.peer_record_map:
if not self.addrs(peer_id):
self.peer_record_map.pop(peer_id, None)
def consume_peer_record(self, envelope: Envelope, ttl: int) -> bool:
"""
Accept and store a signed PeerRecord, unless it's older than
the one already stored.
This function:
- Extracts the peer ID and sequence number from the envelope
- Rejects the record if it's older (lower seq)
- Updates the stored peer record and replaces associated addresses if accepted
:param envelope: Signed envelope containing a PeerRecord.
:param ttl: Time-to-live for the included multiaddrs (in seconds).
:return: True if the record was accepted and stored; False if it was rejected.
"""
record = envelope.record()
peer_id = record.peer_id
existing = self.peer_record_map.get(peer_id)
if existing and existing.seq > record.seq:
return False # reject older record
new_addrs = set(record.addrs)
self.peer_record_map[peer_id] = PeerRecordState(envelope, record.seq)
self.peer_data_map[peer_id].clear_addrs()
self.add_addrs(peer_id, list(new_addrs), ttl)
return True
def consume_peer_records(self, envelopes: list[Envelope], ttl: int) -> list[bool]:
"""Consume multiple peer records in a single operation."""
results = []
for envelope in envelopes:
results.append(self.consume_peer_record(envelope, ttl))
return results
def get_peer_record(self, peer_id: ID) -> Envelope | None:
"""
Retrieve the most recent signed PeerRecord `Envelope` for a peer, if it exists
and is still relevant.
First, it runs cleanup via `maybe_delete_peer_record` to purge stale data.
Then it checks whether the peer has valid, unexpired addresses before
returning the associated envelope.
:param peer_id: The peer to look up.
:return: The signed Envelope if the peer is known and has valid
addresses; None otherwise.
"""
self.maybe_delete_peer_record(peer_id)
# Check if the peer has any valid addresses
if (
peer_id in self.peer_data_map
and not self.peer_data_map[peer_id].is_expired()
):
state = self.peer_record_map.get(peer_id)
if state is not None:
return state.envelope
return None
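A sketch of the seq-gated behaviour described above, using two records signed with the same key (constructor names as assumed elsewhere in this PR):

from multiaddr import Multiaddr

from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.envelope import seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord

key_pair = create_new_key_pair()
pid = ID.from_pubkey(key_pair.public_key)
addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
older = seal_record(PeerRecord(peer_id=pid, addrs=[addr], seq=1), key_pair.private_key)
newer = seal_record(PeerRecord(peer_id=pid, addrs=[addr], seq=2), key_pair.private_key)

store = PeerStore()
assert store.consume_peer_record(newer, 7200) is True   # accepted and stored
assert store.consume_peer_record(older, 7200) is False  # lower seq: rejected
assert store.get_peer_record(pid) is not None           # newer envelope retained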
# -------ADDR-BOOK--------
def add_addr(self, peer_id: ID, addr: Multiaddr, ttl: int = 0) -> None:
@ -322,8 +193,6 @@ class PeerStore(IPeerStore):
except trio.WouldBlock:
pass # Or consider logging / dropping / replacing stream
self.maybe_delete_peer_record(peer_id)
def addrs(self, peer_id: ID) -> list[Multiaddr]:
"""
:param peer_id: peer ID to get addrs for
@ -347,8 +216,6 @@ class PeerStore(IPeerStore):
if peer_id in self.peer_data_map:
self.peer_data_map[peer_id].clear_addrs()
self.maybe_delete_peer_record(peer_id)
def peers_with_addrs(self) -> list[ID]:
"""
        :return: all of the peer IDs which have addrs stored in the peer store

View File

@ -11,10 +11,6 @@ import functools
import hashlib
import logging
import time
from typing import (
NamedTuple,
cast,
)
import base58
import trio
@ -30,8 +26,6 @@ from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import (
AsyncValidatorFn,
SyncValidatorFn,
TProtocol,
ValidatorFn,
)
@ -77,6 +71,11 @@ from .pubsub_notifee import (
from .subscription import (
TrioSubscriptionAPI,
)
from .validation_throttler import (
TopicValidator,
ValidationResult,
ValidationThrottler,
)
from .validators import (
PUBSUB_SIGNING_PREFIX,
signature_validator,
@ -97,14 +96,6 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
return base64.b64encode(hashlib.sha256(msg.data).digest())
class TopicValidator(NamedTuple):
validator: ValidatorFn
is_async: bool
MAX_CONCURRENT_VALIDATORS = 10
class Pubsub(Service, IPubsub):
host: IHost
@ -112,7 +103,6 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
_validator_semaphore: trio.Semaphore
seen_messages: LastSeenCache
@ -147,7 +137,11 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
# TODO: these values have been copied from Go, but try to tune these dynamically
validation_queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
validation_worker_count: int | None = None,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -173,7 +167,6 @@ class Pubsub(Service, IPubsub):
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send, dead_peer_send)
@ -209,7 +202,15 @@ class Pubsub(Service, IPubsub):
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
self.peers = {}
# Map of topic to topic validator
# Validation Throttler
self.validation_throttler = ValidationThrottler(
queue_size=validation_queue_size,
global_throttle_limit=global_throttle_limit,
default_topic_throttle_limit=default_topic_throttle_limit,
worker_count=validation_worker_count or 4,
)
# Keep a mapping of topic -> TopicValidator for easier lookup
self.topic_validators = {}
self.counter = int(time.time())
@ -221,10 +222,19 @@ class Pubsub(Service, IPubsub):
self.event_handle_dead_peer_queue_started = trio.Event()
async def run(self) -> None:
self.manager.run_daemon_task(self._start_validation_throttler)
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
await self.manager.wait_finished()
async def _start_validation_throttler(self) -> None:
"""Start validation throttler in current nursery context"""
async with trio.open_nursery() as nursery:
await self.validation_throttler.start(nursery)
# Keep nursery alive until service stops
while self.manager.is_running:
await self.manager.wait_finished()
@property
def my_id(self) -> ID:
return self.host.get_id()
@ -304,7 +314,12 @@ class Pubsub(Service, IPubsub):
)
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
self,
topic: str,
validator: ValidatorFn,
is_async_validator: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
) -> None:
"""
Register a validator under the given topic. One topic can only have one
@ -313,8 +328,18 @@ class Pubsub(Service, IPubsub):
:param topic: the topic to register validator under
:param validator: the validator used to validate messages published to the topic
:param is_async_validator: indicate if the validator is an asynchronous validator
:param timeout: optional timeout for the validator
:param throttle_limit: optional throttle limit for the validator
""" # noqa: E501
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
# Create throttled topic validator
topic_validator = self.validation_throttler.create_topic_validator(
topic=topic,
validator=validator,
is_async=is_async_validator,
timeout=timeout,
throttle_limit=throttle_limit,
)
self.topic_validators[topic] = topic_validator
def remove_topic_validator(self, topic: str) -> None:
"""
@ -324,17 +349,18 @@ class Pubsub(Service, IPubsub):
"""
self.topic_validators.pop(topic, None)
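For illustration, registering an async validator under the extended signature might look like this, assuming a constructed Pubsub instance `pubsub` (the topic name and validator body are hypothetical):

async def only_small_messages(peer: ID, msg: rpc_pb2.Message) -> bool:
    # Reject anything over 1 KiB.
    return len(msg.data) <= 1024

pubsub.set_topic_validator(
    "demo-topic",
    only_small_messages,
    is_async_validator=True,
    timeout=0.5,        # per-call deadline; expiry is treated as IGNORE
    throttle_limit=64,  # per-topic cap on concurrent validations
)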
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
"""
Get all validators corresponding to the topics in the message.
:param msg: the message published to the topic
:return: list of topic validators for the message's topics
"""
return tuple(
return [
self.topic_validators[topic]
for topic in msg.topicIDs
if topic in self.topic_validators
)
]
def add_to_blacklist(self, peer_id: ID) -> None:
"""
@ -663,60 +689,63 @@ class Pubsub(Service, IPubsub):
logger.debug("successfully published message %s", msg)
async def validate_msg(
self,
msg_forwarder: ID,
msg: rpc_pb2.Message,
) -> None:
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""
Validate the received message.
:param msg_forwarder: the peer who forward us the message.
:param msg: the message.
"""
sync_topic_validators: list[SyncValidatorFn] = []
async_topic_validators: list[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validators.append(
cast(AsyncValidatorFn, topic_validator.validator)
)
else:
sync_topic_validators.append(
cast(SyncValidatorFn, topic_validator.validator)
)
# Get applicable validators for this message
validators = self.get_msg_validators(msg)
for validator in sync_topic_validators:
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
if not validators:
# No validators, accept immediately
return
if len(async_topic_validators) > 0:
# Appends to lists are thread safe in CPython
results: list[bool] = []
# Use trio.Event for async coordination
validation_event = trio.Event()
result_container: dict[str, ValidationResult | None | Exception] = {
"result": None,
"error": None,
}
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(
self._run_async_validator,
async_validator,
msg_forwarder,
msg,
results,
)
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
async def _run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
def handle_validation_result(
result: ValidationResult, error: Exception | None
) -> None:
async with self._validator_semaphore:
result = await func(msg_forwarder, msg)
results.append(result)
result_container["result"] = result
result_container["error"] = error
validation_event.set()
# Submit for throttled validation
success = await self.validation_throttler.submit_validation(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=handle_validation_result,
)
if not success:
# Validation was throttled at queue level
raise ValidationError("Validation throttled at queue level")
# Wait for validation result
await validation_event.wait()
result = result_container["result"]
error = result_container["error"]
if error:
raise ValidationError(f"Validation error: {error}")
if result == ValidationResult.REJECT:
raise ValidationError("Message validation rejected")
elif result == ValidationResult.THROTTLED:
raise ValidationError("Message validation throttled")
elif result == ValidationResult.IGNORE:
# Treat IGNORE as rejection for now, or you could silently drop
raise ValidationError("Message validation ignored")
# ACCEPT case - just return normally
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""

View File

@ -0,0 +1,314 @@
from collections.abc import (
Callable,
)
from dataclasses import dataclass
from enum import Enum
import logging
from typing import (
NamedTuple,
cast,
)
import trio
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
from libp2p.peer.id import (
ID,
)
from .pb import (
rpc_pb2,
)
logger = logging.getLogger("libp2p.pubsub.validation")
class ValidationResult(Enum):
ACCEPT = "accept"
REJECT = "reject"
IGNORE = "ignore"
THROTTLED = "throttled"
@dataclass
class ValidationRequest:
"""Request for message validation"""
validators: list["TopicValidator"]
msg_forwarder: ID # peer ID
msg: rpc_pb2.Message # message object
result_callback: Callable[[ValidationResult, Exception | None], None]
class TopicValidator(NamedTuple):
topic: str
validator: ValidatorFn
is_async: bool
timeout: float | None = None
# Per-topic throttle semaphore
throttle_semaphore: trio.Semaphore | None = None
class ValidationThrottler:
"""Manages all validation throttling mechanisms"""
def __init__(
self,
queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
worker_count: int | None = None,
):
# 1. Queue-level throttling - bounded memory channel
self._validation_send, self._validation_receive = trio.open_memory_channel[
ValidationRequest
](queue_size)
# 2. Global validation throttling - limits total concurrent async validations
self._global_throttle = trio.Semaphore(global_throttle_limit)
# 3. Per-topic throttling - each validator gets its own semaphore
self._default_topic_throttle_limit = default_topic_throttle_limit
# Worker management
# TODO: Find a better way to manage worker count
self._worker_count = worker_count or 4
self._running = False
async def start(self, nursery: trio.Nursery) -> None:
"""Start the validation workers"""
self._running = True
# Start validation worker tasks
for i in range(self._worker_count):
nursery.start_soon(self._validation_worker, f"worker-{i}")
async def stop(self) -> None:
"""Stop the validation system"""
self._running = False
await self._validation_send.aclose()
def create_topic_validator(
self,
topic: str,
validator: ValidatorFn,
is_async: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
) -> TopicValidator:
"""Create a new topic validator with its own throttle"""
limit = throttle_limit or self._default_topic_throttle_limit
throttle_sem = trio.Semaphore(limit)
return TopicValidator(
topic=topic,
validator=validator,
is_async=is_async,
timeout=timeout,
throttle_semaphore=throttle_sem,
)
async def submit_validation(
self,
validators: list[TopicValidator],
msg_forwarder: ID,
msg: rpc_pb2.Message,
result_callback: Callable[[ValidationResult, Exception | None], None],
) -> bool:
"""
Submit a message for validation.
Returns True if queued successfully, False if queue is full (throttled).
"""
if not self._running:
result_callback(
ValidationResult.REJECT, Exception("Validation system not running")
)
return False
request = ValidationRequest(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=result_callback,
)
try:
# This will raise trio.WouldBlock if queue is full
self._validation_send.send_nowait(request)
return True
except trio.WouldBlock:
# Queue-level throttling: drop the message
logger.debug(
"Validation queue full, dropping message from %s", msg_forwarder
)
result_callback(
ValidationResult.THROTTLED, Exception("Validation queue full")
)
return False
async def _validation_worker(self, worker_id: str) -> None:
"""Worker that processes validation requests"""
logger.debug("Validation worker %s started", worker_id)
async with self._validation_receive:
async for request in self._validation_receive:
if not self._running:
break
try:
# Process the validation request
result = await self._validate_message(request)
request.result_callback(result, None)
except Exception as e:
logger.exception("Error in validation worker %s", worker_id)
request.result_callback(ValidationResult.REJECT, e)
logger.debug("Validation worker %s stopped", worker_id)
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
"""Core validation logic with throttling"""
validators = request.validators
msg_forwarder = request.msg_forwarder
msg = request.msg
if not validators:
return ValidationResult.ACCEPT
# Separate sync and async validators
sync_validators = [v for v in validators if not v.is_async]
async_validators = [v for v in validators if v.is_async]
# Run synchronous validators first
for validator in sync_validators:
try:
# Apply per-topic throttling even for sync validators
if validator.throttle_semaphore:
validator.throttle_semaphore.acquire_nowait()
try:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
finally:
validator.throttle_semaphore.release()
else:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
except trio.WouldBlock:
# Per-topic throttling for sync validator
logger.debug("Sync validation throttled for topic %s", validator.topic)
return ValidationResult.THROTTLED
except Exception as e:
logger.exception(
"Sync validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
# Handle async validators with global + per-topic throttling
if async_validators:
return await self._validate_async_validators(
async_validators, msg_forwarder, msg
)
return ValidationResult.ACCEPT
async def _validate_async_validators(
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Handle async validators with proper throttling"""
if len(validators) == 1:
# Fast path for single validator
return await self._validate_single_async_validator(
validators[0], msg_forwarder, msg
)
# Multiple async validators - run them concurrently
try:
# Try to acquire global throttle slot
self._global_throttle.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Global validation throttle exceeded, dropping message from %s",
msg_forwarder,
)
return ValidationResult.THROTTLED
try:
async with trio.open_nursery() as nursery:
results = {}
async def run_validator(validator: TopicValidator, index: int) -> None:
"""Run a single async validator and store the result"""
result = await self._validate_single_async_validator(
validator, msg_forwarder, msg
)
results[index] = result
# Start all validators concurrently
for i, validator in enumerate(validators):
nursery.start_soon(run_validator, validator, i)
# Process results - any reject or throttle causes overall failure
final_result = ValidationResult.ACCEPT
for result in results.values():
if result == ValidationResult.REJECT:
return ValidationResult.REJECT
elif result == ValidationResult.THROTTLED:
final_result = ValidationResult.THROTTLED
elif (
result == ValidationResult.IGNORE
and final_result == ValidationResult.ACCEPT
):
final_result = ValidationResult.IGNORE
return final_result
finally:
self._global_throttle.release()
return ValidationResult.IGNORE
async def _validate_single_async_validator(
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Validate with a single async validator"""
# Apply per-topic throttling
if validator.throttle_semaphore:
try:
validator.throttle_semaphore.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Per-topic validation throttled for topic %s", validator.topic
)
return ValidationResult.THROTTLED
else:
# Fallback if no throttle semaphore configured
pass
try:
# Apply timeout if configured
result: bool
if validator.timeout:
with trio.fail_after(validator.timeout):
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
else:
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
except trio.TooSlowError:
logger.debug("Validation timeout for topic %s", validator.topic)
return ValidationResult.IGNORE
except Exception as e:
logger.exception(
"Async validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
finally:
if validator.throttle_semaphore:
validator.throttle_semaphore.release()
return ValidationResult.IGNORE
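A minimal standalone sketch of driving the throttler directly under trio (the topic, peer id, and message here are placeholders):

import trio

from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2

async def main() -> None:
    throttler = ValidationThrottler(queue_size=8, worker_count=2)
    tv = throttler.create_topic_validator(
        topic="demo", validator=lambda peer, msg: True, is_async=False
    )

    def on_result(result: ValidationResult, error: Exception | None) -> None:
        print("validation finished:", result, error)

    async with trio.open_nursery() as nursery:
        await throttler.start(nursery)
        await throttler.submit_validation(
            [tv], ID(b"demo-peer"), rpc_pb2.Message(), on_result
        )
        await trio.sleep(0.1)   # give a worker time to drain the queue
        await throttler.stop()  # closing the channel lets the workers exit

trio.run(main)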

View File

@ -1,5 +1,3 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -34,72 +32,6 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
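Usage sketch for the lock above (self-contained apart from the ReadWriteLock class itself):

import trio

async def demo() -> None:
    lock = ReadWriteLock()
    shared: list[int] = []

    async def reader(n: int) -> None:
        async with lock.read_lock():   # many readers may hold the lock at once
            print(f"reader {n} sees {list(shared)}")

    async def writer() -> None:
        async with lock.write_lock():  # excludes readers and other writers
            shared.append(len(shared))

    async with trio.open_nursery() as nursery:
        for n in range(3):
            nursery.start_soon(reader, n)
        nursery.start_soon(writer)

trio.run(demo)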
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@ -114,7 +46,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
rw_lock: ReadWriteLock
# TODO: Add lock for read/write to avoid interleaving receiving messages?
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@ -148,7 +80,6 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@ -182,7 +113,6 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
@ -194,8 +124,8 @@ class MplexStream(IMuxedStream):
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
@ -213,12 +143,12 @@ class MplexStream(IMuxedStream):
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
@ -232,7 +162,6 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
@ -247,6 +176,8 @@ class MplexStream(IMuxedStream):
Closing a stream closes it for writing and closes the remote end for
reading but allows writing in the other direction.
"""
# TODO error handling with timeout
async with self.close_lock:
if self.event_local_closed.is_set():
return
@ -254,17 +185,8 @@ class MplexStream(IMuxedStream):
flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
)
try:
with trio.fail_after(5): # timeout in seconds
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id)
except trio.TooSlowError:
raise TimeoutError("Timeout while trying to close the stream")
except MuxedConnUnavailable:
if not self.muxed_conn.event_shutting_down.is_set():
raise RuntimeError(
"Failed to send close message and Mplex isn't shutting down"
)
_is_remote_closed: bool
async with self.close_lock:

View File

@ -45,9 +45,6 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamReset,
)
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
PROTOCOL_ID = "/yamux/1.0.0"
TYPE_DATA = 0x0
TYPE_WINDOW_UPDATE = 0x1
@ -101,13 +98,13 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
@ -155,12 +152,12 @@ class YamuxStream(IMuxedStream):
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
@ -188,7 +185,7 @@ class YamuxStream(IMuxedStream):
# If the stream is closed for receiving and the buffer is empty, raise EOF
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
)
raise MuxedStreamEOF("Stream is closed for receiving")
@ -201,7 +198,7 @@ class YamuxStream(IMuxedStream):
# If buffer is not available, check if stream is closed
if buffer is None:
logger.debug(f"Stream {self.stream_id}: No buffer available")
logging.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
@ -213,34 +210,34 @@ class YamuxStream(IMuxedStream):
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logger.debug(f"Stream {self.stream_id}: Stream was reset")
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
@ -249,7 +246,7 @@ class YamuxStream(IMuxedStream):
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logger.debug(
logging.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
@ -258,7 +255,7 @@ class YamuxStream(IMuxedStream):
async def close(self) -> None:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
@ -274,7 +271,7 @@ class YamuxStream(IMuxedStream):
async def reset(self) -> None:
if not self.closed:
logger.debug(f"Resetting stream {self.stream_id}")
logging.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
@ -352,7 +349,7 @@ class Yamux(IMuxedConn):
self._nursery: Nursery | None = None
async def start(self) -> None:
logger.debug(f"Starting Yamux for {self.peer_id}")
logging.debug(f"Starting Yamux for {self.peer_id}")
if self.event_started.is_set():
return
async with trio.open_nursery() as nursery:
@ -365,7 +362,7 @@ class Yamux(IMuxedConn):
return self.is_initiator_value
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
logger.debug(f"Closing Yamux connection with code {error_code}")
logging.debug(f"Closing Yamux connection with code {error_code}")
async with self.streams_lock:
if not self.event_shutting_down.is_set():
try:
@ -374,7 +371,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(header)
except Exception as e:
logger.debug(f"Failed to send GO_AWAY: {e}")
logging.debug(f"Failed to send GO_AWAY: {e}")
self.event_shutting_down.set()
for stream in self.streams.values():
stream.closed = True
@ -385,12 +382,12 @@ class Yamux(IMuxedConn):
self.stream_events.clear()
try:
await self.secured_conn.close()
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as e:
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
self.event_closed.set()
if self.on_close:
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
if inspect.iscoroutinefunction(self.on_close):
if self.on_close is not None:
await self.on_close()
@ -419,7 +416,7 @@ class Yamux(IMuxedConn):
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
)
logger.debug(f"Sending SYN header for stream {stream_id}")
logging.debug(f"Sending SYN header for stream {stream_id}")
await self.secured_conn.write(header)
return stream
except Exception as e:
@ -427,32 +424,32 @@ class Yamux(IMuxedConn):
raise e
async def accept_stream(self) -> IMuxedStream:
logger.debug("Waiting for new stream")
logging.debug("Waiting for new stream")
try:
stream = await self.new_stream_receive_channel.receive()
logger.debug(f"Received stream {stream.stream_id}")
logging.debug(f"Received stream {stream.stream_id}")
return stream
except trio.EndOfChannel:
raise MuxedStreamError("No new streams available")
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
if n is None:
n = -1
while True:
async with self.streams_lock:
if stream_id not in self.streams:
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
raise MuxedStreamEOF("Stream closed")
if self.event_shutting_down.is_set():
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
)
raise MuxedStreamEOF("Connection shut down")
stream = self.streams[stream_id]
buffer = self.stream_buffers.get(stream_id)
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}: "
f"closed={stream.closed}, "
f"recv_closed={stream.recv_closed}, "
@ -460,7 +457,7 @@ class Yamux(IMuxedConn):
f"buffer_len={len(buffer) if buffer else 0}"
)
if buffer is None:
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"Buffer gone, assuming closed"
)
@ -473,7 +470,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logger.debug(
logging.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -481,7 +478,7 @@ class Yamux(IMuxedConn):
return data
# If reset received and buffer is empty, raise reset
if stream.reset_received:
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"reset_received=True, raising MuxedStreamReset"
)
@ -494,7 +491,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logger.debug(
logging.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -502,21 +499,21 @@ class Yamux(IMuxedConn):
return data
# Check if stream is closed
if stream.closed:
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"closed=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream is reset or closed")
# Check if recv_closed and buffer empty
if stream.recv_closed:
logger.debug(
logging.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"recv_closed=True, buffer empty, raising EOF"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# Wait for data if stream is still open
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
try:
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
@ -531,7 +528,7 @@ class Yamux(IMuxedConn):
try:
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logger.debug(
logging.debug(
f"Connection closed orincomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
@ -540,7 +537,7 @@ class Yamux(IMuxedConn):
version, typ, flags, stream_id, length = struct.unpack(
YAMUX_HEADER_FORMAT, header
)
logger.debug(
logging.debug(
f"Received header for peer {self.peer_id}:"
f"type={typ}, flags={flags}, stream_id={stream_id},"
f"length={length}"
@ -561,7 +558,7 @@ class Yamux(IMuxedConn):
0,
)
await self.secured_conn.write(ack_header)
logger.debug(
logging.debug(
f"Sending stream {stream_id}"
f"to channel for peer {self.peer_id}"
)
@ -579,7 +576,7 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_RST:
async with self.streams_lock:
if stream_id in self.streams:
logger.debug(
logging.debug(
f"Resetting stream {stream_id} for peer {self.peer_id}"
)
self.streams[stream_id].closed = True
@ -588,27 +585,27 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_ACK:
async with self.streams_lock:
if stream_id in self.streams:
logger.debug(
logging.debug(
f"Received ACK for stream"
f"{stream_id} for peer {self.peer_id}"
)
elif typ == TYPE_GO_AWAY:
error_code = length
if error_code == GO_AWAY_NORMAL:
logger.debug(
logging.debug(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Normal termination"
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logger.error(
logging.error(
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logger.error(
logging.error(
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
)
else:
logger.error(
logging.error(
f"Received GO_AWAY for peer {self.peer_id}"
f"with unknown error code: {error_code}"
)
@ -617,7 +614,7 @@ class Yamux(IMuxedConn):
break
elif typ == TYPE_PING:
if flags & FLAG_SYN:
logger.debug(
logging.debug(
f"Received ping request with value"
f"{length} for peer {self.peer_id}"
)
@ -626,7 +623,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(ping_header)
elif flags & FLAG_ACK:
logger.debug(
logging.debug(
f"Received ping response with value"
f"{length} for peer {self.peer_id}"
)
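The ping branch above answers a SYN ping by echoing the opaque value back with an ACK flag. A standalone sketch of packing such frames (header format and constant values are assumed to match this module's definitions):
import struct

YAMUX_HEADER_FORMAT = "!BBHII"  # version, type, flags, stream_id, length (assumed)
TYPE_PING = 0x2
FLAG_SYN, FLAG_ACK = 0x1, 0x2

def make_ping(value: int, ack: bool = False) -> bytes:
    # Pings travel on stream 0; the opaque value rides in the length field.
    flags = FLAG_ACK if ack else FLAG_SYN
    return struct.pack(YAMUX_HEADER_FORMAT, 0, TYPE_PING, flags, 0, value)

assert len(make_ping(42)) == 12  # one full header, no payload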
@ -640,7 +637,7 @@ class Yamux(IMuxedConn):
self.stream_buffers[stream_id].extend(data)
self.stream_events[stream_id].set()
if flags & FLAG_FIN:
logger.debug(
logging.debug(
f"Received FIN for stream {self.peer_id}:"
f"{stream_id}, marking recv_closed"
)
@ -648,7 +645,7 @@ class Yamux(IMuxedConn):
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
except Exception as e:
logger.error(f"Error reading data for stream {stream_id}: {e}")
logging.error(f"Error reading data for stream {stream_id}: {e}")
# Mark stream as closed on read error
async with self.streams_lock:
if stream_id in self.streams:
@ -662,7 +659,7 @@ class Yamux(IMuxedConn):
if stream_id in self.streams:
stream = self.streams[stream_id]
async with stream.window_lock:
logger.debug(
logging.debug(
f"Received window update for stream"
f"{self.peer_id}:{stream_id},"
f" increment: {increment}"
@ -677,7 +674,7 @@ class Yamux(IMuxedConn):
and details.get("requested_count") == 2
and details.get("received_count") == 0
):
logger.info(
logging.info(
f"Stream closed cleanly for peer {self.peer_id}"
+ f" (IncompleteReadError: {details})"
)
@ -685,29 +682,12 @@ class Yamux(IMuxedConn):
await self._cleanup_on_error()
break
else:
logger.error(
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
else:
# Handle RawConnError with more nuance
if isinstance(e, RawConnError):
error_msg = str(e)
# If RawConnError is empty, it's likely normal cleanup
if not error_msg.strip():
logger.info(
f"RawConnError (empty) during cleanup for peer "
f"{self.peer_id} (normal connection shutdown)"
)
else:
# Log non-empty RawConnError as warning
logger.warning(
f"RawConnError during connection handling for peer "
f"{self.peer_id}: {error_msg}"
)
else:
# Log all other errors normally
logger.error(
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
@ -740,9 +720,9 @@ class Yamux(IMuxedConn):
# Close the secured connection
try:
await self.secured_conn.close()
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as close_error:
logger.error(
logging.error(
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
)
@ -751,14 +731,14 @@ class Yamux(IMuxedConn):
# Call on_close callback if provided
if self.on_close:
logger.debug(f"Calling on_close for peer {self.peer_id}")
logging.debug(f"Calling on_close for peer {self.peer_id}")
try:
if inspect.iscoroutinefunction(self.on_close):
await self.on_close()
else:
self.on_close()
except Exception as callback_error:
logger.error(f"Error in on_close callback: {callback_error}")
logging.error(f"Error in on_close callback: {callback_error}")
# Cancel nursery tasks
if self._nursery:

View File

@ -9,7 +9,6 @@ from libp2p.utils.varint import (
read_varint_prefixed_bytes,
decode_varint_from_bytes,
decode_varint_with_size,
read_length_prefixed_protobuf,
)
from libp2p.utils.version import (
get_agent_version,
@ -25,5 +24,4 @@ __all__ = [
"read_varint_prefixed_bytes",
"decode_varint_from_bytes",
"decode_varint_with_size",
"read_length_prefixed_protobuf",
]

View File

@ -1,9 +1,7 @@
import itertools
import logging
import math
from typing import BinaryIO
from libp2p.abc import INetStream
from libp2p.exceptions import (
ParseError,
)
@ -27,41 +25,42 @@ HIGH_MASK = 2**7
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
def encode_uvarint(value: int) -> bytes:
"""Encode an unsigned integer as a varint."""
if value < 0:
raise ValueError("Cannot encode negative value as uvarint")
result = bytearray()
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
def decode_uvarint(data: bytes) -> int:
"""Decode a varint from bytes."""
if not data:
raise ParseError("Unexpected end of data")
result = 0
shift = 0
for byte in data:
result |= (byte & 0x7F) << shift
if (byte & 0x80) == 0:
def encode_uvarint(number: int) -> bytes:
"""Pack `number` into varint bytes."""
buf = b""
while True:
towrite = number & 0x7F
number >>= 7
if number:
buf += bytes((towrite | 0x80,))
else:
buf += bytes((towrite,))
break
shift += 7
if shift >= 64:
raise ValueError("Varint too long")
return result
return buf
def decode_varint_from_bytes(data: bytes) -> int:
"""Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
return decode_uvarint(data)
"""
Decode a varint from bytes and return the value.
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
"""
res = 0
for shift in itertools.count(0, 7):
if shift > SHIFT_64_BIT_MAX:
raise ParseError("Integer is too large...")
if not data:
raise ParseError("Unexpected end of data")
value = data[0]
data = data[1:]
res += (value & LOW_MASK) << shift
if not value & HIGH_MASK:
break
return res
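Both codecs above implement unsigned LEB128: seven payload bits per byte, with the high bit marking continuation. A quick round trip under those definitions:
from libp2p.utils.varint import decode_varint_from_bytes, encode_uvarint

# 300 = 0b1_0010_1100 -> 0xAC (low 7 bits plus continuation bit), then 0x02.
assert encode_uvarint(300) == b"\xac\x02"
assert decode_varint_from_bytes(b"\xac\x02") == 300

# Values below 0x80 encode as a single byte.
assert encode_uvarint(1) == b"\x01"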
async def decode_uvarint_from_stream(reader: Reader) -> int:
@ -85,33 +84,34 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
"""
Decode a varint from bytes and return both the value and the number of bytes
consumed.
Returns:
Tuple[int, int]: (value, bytes_consumed)
Decode a varint from bytes and return (value, bytes_consumed).
Returns (0, 0) if the data doesn't start with a valid varint.
"""
result = 0
shift = 0
bytes_consumed = 0
for byte in data:
result |= (byte & 0x7F) << shift
bytes_consumed += 1
try:
# Calculate how many bytes the varint consumes
varint_size = 0
for i, byte in enumerate(data):
varint_size += 1
if (byte & 0x80) == 0:
break
shift += 7
if shift >= 64:
raise ValueError("Varint too long")
return result, bytes_consumed
if varint_size == 0:
return 0, 0
# Extract just the varint bytes
varint_bytes = data[:varint_size]
# Decode the varint
value = decode_varint_from_bytes(varint_bytes)
return value, varint_size
except Exception:
return 0, 0
def encode_varint_prefixed(data: bytes) -> bytes:
"""Encode data with a varint length prefix."""
length_bytes = encode_uvarint(len(data))
return length_bytes + data
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
varint_len = encode_uvarint(len(msg_bytes))
return varint_len + msg_bytes
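Together, `encode_varint_prefixed` and `decode_varint_with_size` frame and unframe a message; a short sketch assuming the definitions above:
from libp2p.utils.varint import decode_varint_with_size, encode_varint_prefixed

framed = encode_varint_prefixed(b"hello")  # b"\x05hello"
length, consumed = decode_varint_with_size(framed)
assert (length, consumed) == (5, 1)
assert framed[consumed : consumed + length] == b"hello"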
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
@ -138,95 +138,3 @@ async def read_delim(reader: Reader) -> bytes:
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
)
return msg_bytes[:-1]
def read_varint_prefixed_bytes_sync(
stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
"""
Read varint-prefixed bytes from a stream.
Args:
stream: A stream-like object with a read() method
max_length: Maximum allowed data length to prevent memory exhaustion
Returns:
bytes: The data without the length prefix
Raises:
ValueError: If the length prefix is invalid or too large
EOFError: If the stream ends unexpectedly
"""
# Read the varint length prefix
length_bytes = b""
while True:
byte_data = stream.read(1)
if not byte_data:
raise EOFError("Stream ended while reading varint length prefix")
length_bytes += byte_data
if byte_data[0] & 0x80 == 0:
break
# Decode the length
length = decode_uvarint(length_bytes)
if length > max_length:
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
# Read the data
data = stream.read(length)
if len(data) != length:
raise EOFError(f"Expected {length} bytes, got {len(data)}")
return data
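Because the synchronous reader above only needs a blocking `read()`, it can be exercised against an in-memory stream; a sketch assuming the helpers in this file are importable:
import io

from libp2p.utils.varint import (
    encode_uvarint,
    encode_varint_prefixed,
    read_varint_prefixed_bytes_sync,
)

framed = encode_varint_prefixed(b"payload")
assert read_varint_prefixed_bytes_sync(io.BytesIO(framed)) == b"payload"

# An oversized length prefix is rejected before any payload is read.
oversized = encode_uvarint(2 * 1024 * 1024)  # prefix alone, claims 2 MiB
try:
    read_varint_prefixed_bytes_sync(io.BytesIO(oversized), max_length=1024 * 1024)
except ValueError as exc:
    print(exc)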
async def read_length_prefixed_protobuf(
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
"""Read a protobuf message from a stream, handling both formats."""
if use_varint_format:
# Read length-prefixed protobuf message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("No length prefix received")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
if msg_length > max_length:
raise Exception(
f"Message length {msg_length} exceeds maximum allowed {max_length}"
)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length}, got {len(data)}"
)
return data
else:
# Read raw protobuf message from the stream
# For raw format, read all available data in one go
data = await stream.read()
# If we got no data, raise an exception
if not data:
raise Exception("No data received in raw format")
if len(data) > max_length:
raise Exception(
f"Message length {len(data)} exceeds maximum allowed {max_length}"
)
return data
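A minimal way to exercise both branches of `read_length_prefixed_protobuf` is a stub with the async `read` signature it expects (the stub is illustrative, not a real `INetStream`):
import trio

from libp2p.utils.varint import (
    encode_varint_prefixed,
    read_length_prefixed_protobuf,
)

class StubStream:
    def __init__(self, data: bytes) -> None:
        self._buf = bytearray(data)

    async def read(self, n: int | None = None) -> bytes:
        if n is None:  # the raw-format branch reads everything at once
            n = len(self._buf)
        out = bytes(self._buf[:n])
        del self._buf[:n]
        return out

async def demo() -> None:
    msg = b"\x08\x01"  # any serialized protobuf bytes would do here
    framed = StubStream(encode_varint_prefixed(msg))
    assert await read_length_prefixed_protobuf(framed, use_varint_format=True) == msg
    assert await read_length_prefixed_protobuf(StubStream(msg), use_varint_format=False) == msg

trio.run(demo)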

View File

@ -1 +0,0 @@
Added `Bootstrap` peer discovery module that allows nodes to connect to predefined bootstrap peers for network discovery.

View File

@ -1 +0,0 @@
Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py

View File

@ -1 +0,0 @@
[mplex] Add timeout and error handling during stream close

View File

@ -1,2 +0,0 @@
Added the `Certified Addr-Book` interface supported by `Envelope` and `PeerRecord` class.
Integrated the signed-peer-record transfer in the identify/push protocols.

View File

@ -1,2 +0,0 @@
Added throttling for async topic validators in validate_msg, enforcing a
concurrency limit to prevent resource exhaustion under heavy load.

View File

@ -1 +0,0 @@
Pin py-multiaddr dependency to specific git commit db8124e2321f316d3b7d2733c7df11d6ad9c03e6

View File

@ -1 +0,0 @@
Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.

View File

@ -1 +0,0 @@
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.

View File

@ -1 +0,0 @@
Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.

View File

@ -1 +0,0 @@
Fixed incorrect handling of raw protobuf format in identify push protocol. The identify push example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.

View File

@ -1 +0,0 @@
Yamux RawConnError Logging Refactor - Improved error handling and debug logging

View File

@ -18,19 +18,12 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
- `performance`
- `removal`
So for example: `1024.feature.rst`
**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks.
```
Added support for Ed25519 key generation in libp2p peer identity creation.
```
So for example: `123.feature.rst`, `456.bugfix.rst`
If the PR fixes an issue, use that number here. If there is no issue,
then open up the PR first and use the PR number for the newsfragment.
**Note** that the `towncrier` tool will automatically
Note that the `towncrier` tool will automatically
reflow your text, so don't try to do any fancy formatting. Run
`towncrier build --draft` to get a preview of what the release notes entry
will look like in the final release notes.

View File

@ -19,11 +19,10 @@ dependencies = [
"exceptiongroup>=1.2.0; python_version < '3.11'",
"grpcio>=1.41.0",
"lru-dict>=1.1.6",
# "multiaddr>=0.0.9",
"multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6",
"multiaddr>=0.0.9",
"mypy-protobuf>=3.0.0",
"noiseprotocol>=0.3.0",
"protobuf>=4.25.0,<5.0.0",
"protobuf>=4.21.0,<5.0.0",
"pycryptodome>=3.9.2",
"pymultihash>=0.8.2",
"pynacl>=1.3.0",

View File

@ -13,8 +13,6 @@ from libp2p.identity.identify.identify import (
_multiaddr_to_bytes,
parse_identify_response,
)
from libp2p.peer.envelope import Envelope, consume_envelope, unmarshal_envelope
from libp2p.peer.peer_record import unmarshal_record
from tests.utils.factories import (
host_pair_factory,
)
@ -42,19 +40,6 @@ async def test_identify_protocol(security_protocol):
# Parse the response (handles both old and new formats)
identify_response = parse_identify_response(response)
# Validate the received envelope and then store it in the certified-addr-book
envelope, record = consume_envelope(
identify_response.signedPeerRecord, "libp2p-peer-record"
)
assert host_b.peerstore.consume_peer_record(envelope, ttl=7200)
# Check if the peer_id in the record is same as of host_a
assert record.peer_id == host_a.get_id()
# Check if the peer-record is correctly consumed
assert host_a.get_addrs() == host_b.peerstore.addrs(host_a.get_id())
assert isinstance(host_b.peerstore.get_peer_record(host_a.get_id()), Envelope)
logger.debug("host_a: %s", host_a.get_addrs())
logger.debug("host_b: %s", host_b.get_addrs())
@ -86,14 +71,5 @@ async def test_identify_protocol(security_protocol):
# Check protocols
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())
# sanity check if the peer_id of the identify msg are same
assert (
unmarshal_record(
unmarshal_envelope(identify_response.signedPeerRecord).raw_payload
).peer_id
== unmarshal_record(
unmarshal_envelope(
_mk_identify_protobuf(host_a, cleaned_addr).signedPeerRecord
).raw_payload
).peer_id
)
# sanity check
assert identify_response == _mk_identify_protobuf(host_a, cleaned_addr)

View File

@ -1,241 +0,0 @@
import logging
import pytest
from libp2p.custom_types import TProtocol
from libp2p.identity.identify.identify import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_multiaddr_to_bytes,
identify_handler_for,
parse_identify_response,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-integration-test")
@pytest.mark.trio
async def test_identify_protocol_varint_format_integration(security_protocol):
"""Test identify protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_protocol_raw_format_integration(security_protocol):
"""Test identify protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
assert result.listen_addrs == [
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
]
@pytest.mark.trio
async def test_identify_default_format_behavior(security_protocol):
"""Test identify protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use default identify handler (should use varint format)
host_a.set_stream_handler(ID, identify_handler_for(host_a))
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
"""Test varint dialer with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
"""Test raw dialer with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Host A uses varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Host B makes request (will automatically detect format)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response (should work with automatic format detection)
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_format_detection_robustness(security_protocol):
"""Test identify protocol format detection is robust with various message sizes."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test both formats with different message sizes
for use_varint in [True, False]:
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=use_varint)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_large_message_handling(security_protocol):
"""Test identify protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
# Make identify request
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read(8192)
await stream.close()
# Parse response
result = parse_identify_response(response)
# Verify response content
assert result.agent_version == AGENT_VERSION
assert result.protocol_version == PROTOCOL_VERSION
assert result.public_key == host_a.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent messages in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Test varint format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=True)
)
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
response_varint = await stream_varint.read(8192)
await stream_varint.close()
# Test raw format
host_a.set_stream_handler(
ID, identify_handler_for(host_a, use_varint_format=False)
)
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
response_raw = await stream_raw.read(8192)
await stream_raw.close()
# Parse both responses
result_varint = parse_identify_response(response_varint)
result_raw = parse_identify_response(response_raw)
# Both should produce identical parsed results
assert result_varint.agent_version == result_raw.agent_version
assert result_varint.protocol_version == result_raw.protocol_version
assert result_varint.public_key == result_raw.public_key
assert result_varint.listen_addrs == result_raw.listen_addrs

View File

@ -0,0 +1,410 @@
import pytest
from libp2p.identity.identify.identify import (
_mk_identify_protobuf,
)
from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.io.abc import Closer, Reader, Writer
from libp2p.utils.varint import (
decode_varint_from_bytes,
encode_varint_prefixed,
)
from tests.utils.factories import (
host_pair_factory,
)
class MockStream(Reader, Writer, Closer):
"""Mock stream for testing identify protocol compatibility."""
def __init__(self, data: bytes):
self.data = data
self.position = 0
self.closed = False
async def read(self, n: int | None = None) -> bytes:
if self.closed or self.position >= len(self.data):
return b""
if n is None:
n = len(self.data) - self.position
result = self.data[self.position : self.position + n]
self.position += len(result)
return result
async def write(self, data: bytes) -> None:
# Mock write - just store the data
pass
async def close(self) -> None:
self.closed = True
def create_identify_message(host, observed_multiaddr=None):
"""Create an identify protobuf message."""
return _mk_identify_protobuf(host, observed_multiaddr)
def create_new_format_message(identify_msg):
"""Create a new format (length-prefixed) identify message."""
msg_bytes = identify_msg.SerializeToString()
return encode_varint_prefixed(msg_bytes)
def create_old_format_message(identify_msg):
"""Create an old format (raw protobuf) identify message."""
return identify_msg.SerializeToString()
async def read_new_format_message(stream) -> bytes:
"""Read a new format (length-prefixed) identify message."""
# Read varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
raise ValueError("No length prefix received")
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
response = await stream.read(msg_length)
if len(response) != msg_length:
raise ValueError("Incomplete message received")
return response
async def read_old_format_message(stream) -> bytes:
"""Read an old format (raw protobuf) identify message."""
# Read all available data
response = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message(stream) -> bytes:
"""Read an identify message in either old or new format."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining protobuf message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
message_data = first_bytes[varint_len:] + remaining_bytes
# Try to parse as protobuf to validate
try:
Identify().ParseFromString(message_data)
return message_data
except Exception:
# If protobuf parsing fails, fall back to old format
pass
except Exception:
pass
# Fall back to old format (raw protobuf)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
async def read_compatible_message_simple(stream) -> bytes:
"""Read a message in either old or new format (simplified version for testing)."""
# Try to read a few bytes to detect the format
first_bytes = await stream.read(10)
if not first_bytes:
raise ValueError("No data received")
# Try to decode as varint length prefix (new format)
try:
msg_length = decode_varint_from_bytes(first_bytes)
# Validate that the length is reasonable (not too large)
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate how many bytes the varint consumed
varint_len = 0
for i, byte in enumerate(first_bytes):
varint_len += 1
if (byte & 0x80) == 0:
break
# Read the remaining message
remaining_bytes = await stream.read(
msg_length - (len(first_bytes) - varint_len)
)
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
return first_bytes[varint_len:] + remaining_bytes
except Exception:
pass
# Fall back to old format (raw data)
response = first_bytes
# Read more data if available
while True:
chunk = await stream.read(4096)
if not chunk:
break
response += chunk
return response
def detect_format(data):
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
if not data:
return "unknown"
# Try to decode as varint
try:
msg_length = decode_varint_from_bytes(data)
# Validate that the length is reasonable
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
# Calculate varint length
varint_len = 0
for i, byte in enumerate(data):
varint_len += 1
if (byte & 0x80) == 0:
break
# Check if we have enough data for the message
if len(data) >= varint_len + msg_length:
# Additional check: try to parse the message as protobuf
try:
message_data = data[varint_len : varint_len + msg_length]
Identify().ParseFromString(message_data)
return "new"
except Exception:
# If protobuf parsing fails, it's probably not a valid new format
pass
except Exception:
pass
# If varint decoding fails or length is unreasonable, assume old format
return "old"
@pytest.mark.trio
async def test_identify_new_format_compatibility(security_protocol):
"""Test that identify protocol works with new format (length-prefixed) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_old_format_compatibility(security_protocol):
"""Test that identify protocol works with old format (raw protobuf) messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_old_format(security_protocol):
"""Test backward compatibility reader with old format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create old format message
old_format_data = create_old_format_message(identify_msg)
# Create mock stream with old format data
stream = MockStream(old_format_data)
# Read using old format reader (which should work reliably)
response = await read_old_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_backward_compatibility_new_format(security_protocol):
"""Test backward compatibility reader with new format messages."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create new format message
new_format_data = create_new_format_message(identify_msg)
# Create mock stream with new format data
stream = MockStream(new_format_data)
# Read using new format reader (which should work reliably)
response = await read_new_format_message(stream)
# Parse the response
parsed_msg = Identify()
parsed_msg.ParseFromString(response)
# Verify the message content
assert parsed_msg.protocol_version == identify_msg.protocol_version
assert parsed_msg.agent_version == identify_msg.agent_version
assert parsed_msg.public_key == identify_msg.public_key
@pytest.mark.trio
async def test_identify_format_detection(security_protocol):
"""Test that the format detection works correctly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Test new format detection
new_format_data = create_new_format_message(identify_msg)
format_type = detect_format(new_format_data)
assert format_type == "new", "New format should be detected correctly"
# Test old format detection
old_format_data = create_old_format_message(identify_msg)
format_type = detect_format(old_format_data)
assert format_type == "old", "Old format should be detected correctly"
@pytest.mark.trio
async def test_identify_error_handling(security_protocol):
"""Test error handling for malformed messages."""
from libp2p.exceptions import ParseError
# Test with empty data
stream = MockStream(b"")
with pytest.raises(ValueError, match="No data received"):
await read_compatible_message(stream)
# Test with incomplete varint
stream = MockStream(b"\x80") # Incomplete varint
with pytest.raises(ParseError, match="Unexpected end of data"):
await read_new_format_message(stream)
# Test with invalid protobuf data
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
with pytest.raises(Exception): # Should fail when parsing protobuf
response = await read_new_format_message(stream)
Identify().ParseFromString(response)
@pytest.mark.trio
async def test_identify_message_equivalence(security_protocol):
"""Test that old and new format messages are equivalent."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create identify message
identify_msg = create_identify_message(host_a)
# Create both formats
new_format_data = create_new_format_message(identify_msg)
old_format_data = create_old_format_message(identify_msg)
# Extract the protobuf message from new format
varint_len = 0
for i, byte in enumerate(new_format_data):
varint_len += 1
if (byte & 0x80) == 0:
break
new_format_protobuf = new_format_data[varint_len:]
# The protobuf messages should be identical
assert new_format_protobuf == old_format_data, (
"Protobuf messages should be identical in both formats"
)

View File

@ -1,552 +0,0 @@
import logging
import pytest
import trio
from libp2p.custom_types import TProtocol
from libp2p.identity.identify_push.identify_push import (
ID_PUSH,
identify_push_handler_for,
push_identify_to_peer,
push_identify_to_peers,
)
from tests.utils.factories import host_pair_factory
logger = logging.getLogger("libp2p.identity.identify-push-integration-test")
@pytest.mark.trio
async def test_identify_push_protocol_varint_format_integration(security_protocol):
"""Test identify/push protocol with varint format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to host_b so it has something to push
async def dummy_handler(stream):
pass
host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
# The protocols should include the dummy protocols we added
assert len(protocols) >= 2 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_protocol_raw_format_integration(security_protocol):
"""Test identify/push protocol with raw format in real network scenario."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_default_format_behavior(security_protocol):
"""Test identify/push protocol uses correct default format."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Use default identify/push handler (should use varint format)
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 1 # Should include the dummy protocol
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_varint_to_raw(
security_protocol,
):
"""Test varint pusher with raw listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=False)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses raw format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
# Host B pushes with varint format (should fail gracefully)
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=True
)
# The push may fail due to the format mismatch, but format detection
# can be more robust than expected, so we only check that the
# operation completes and returns a boolean
assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_cross_format_compatibility_raw_to_varint(
security_protocol,
):
"""Test raw pusher with varint listener compatibility."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Use an event to signal when handler is ready
handler_ready = trio.Event()
# Create a wrapper handler that signals when ready
original_handler = identify_push_handler_for(host_a, use_varint_format=True)
async def wrapped_handler(stream):
handler_ready.set() # Signal that handler is ready
await original_handler(stream)
# Host A uses varint format with wrapped handler
host_a.set_stream_handler(ID_PUSH, wrapped_handler)
# Host B pushes with raw format (should fail gracefully)
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=False
)
# The push may fail due to the format mismatch, but format detection
# can be more robust than expected, so we only check that the
# operation completes and returns a boolean
assert isinstance(success, bool)
@pytest.mark.trio
async def test_identify_push_multiple_peers_integration(security_protocol):
"""Test identify/push protocol with multiple peers."""
# Create two hosts using the factory
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a third host following the pattern from test_identify_push.py
import multiaddr
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.peerinfo import info_from_p2p_addr
# Create a new key pair for host_c
key_pair_c = create_new_key_pair()
host_c = new_host(key_pair=key_pair_c)
# Set up identify/push handlers on all hosts
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b))
host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c))
# Start listening on a random port using the run context manager
listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
async with host_c.run([listen_addr]):
# Connect host_c to host_a and host_b using the correct pattern
await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0]))
await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0]))
# Push identify information from host_a to all connected peers
await push_identify_to_peers(host_a)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_b's peerstore has been updated with host_a's information
peerstore_b = host_b.get_peerstore()
peer_id_a = host_a.get_id()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_b.peer_ids()
# Check that host_c's peerstore has been updated with host_a's information
peerstore_c = host_c.get_peerstore()
# Check that the peer is in the peerstore
assert peer_id_a in peerstore_c.peer_ids()
# Test for push_identify to only connected peers and not all peers
# Disconnect a from c.
await host_c.disconnect(host_a.get_id())
await push_identify_to_peers(host_c)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Check that host_a's peerstore has not been updated with host_c's info
assert host_c.get_id() not in host_a.get_peerstore().peer_ids()
# Check that host_b's peerstore has been updated with host_c's info
assert host_c.get_id() in host_b.get_peerstore().peer_ids()
@pytest.mark.trio
async def test_identify_push_large_message_handling(security_protocol):
"""Test identify/push protocol handles large messages with many protocols."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add many protocols to make the message larger
async def dummy_handler(stream):
pass
for i in range(10):
host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
# Also add some protocols to host_a to ensure it has protocols to push
for i in range(5):
host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
# Push identify information from host_b to host_a
success = await push_identify_to_peer(
host_b, host_a.get_id(), use_varint_format=True
)
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated with all protocols
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) >= 10 # Should include the dummy protocols
@pytest.mark.trio
async def test_identify_push_peerstore_update_completeness(security_protocol):
"""Test that identify/push updates all relevant peerstore information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
await push_identify_to_peer(host_b, host_a.get_id())
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that protocols were added
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# Check that public key was added
pubkey = peerstore_a.pubkey(peer_id_b)
assert pubkey is not None
assert pubkey.serialize() == host_b.get_public_key().serialize()
@pytest.mark.trio
async def test_identify_push_concurrent_requests(security_protocol):
"""Test identify/push protocol handles concurrent requests properly."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Make multiple concurrent push requests
results = []
async def push_identify():
result = await push_identify_to_peer(host_b, host_a.get_id())
results.append(result)
# Run multiple concurrent pushes using nursery
async with trio.open_nursery() as nursery:
for _ in range(3):
nursery.start_soon(push_identify)
# All should succeed
assert len(results) == 3
assert all(results)
# Wait a bit for the pushes to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_stream_handling(security_protocol):
"""Test identify/push protocol properly handles stream lifecycle."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Push identify information from host_b to host_a
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols = peerstore_a.get_protocols(peer_id_b)
assert protocols is not None
assert len(protocols) > 0
@pytest.mark.trio
async def test_identify_push_error_handling(security_protocol):
"""Test identify/push protocol handles errors gracefully."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Create a handler that raises an exception but catches it to prevent test
# failure
async def error_handler(stream):
try:
await stream.close()
raise Exception("Test error")
except Exception:
# Catch the exception to prevent it from propagating up
pass
host_a.set_stream_handler(ID_PUSH, error_handler)
# Push should complete (message sent) but handler should fail gracefully
success = await push_identify_to_peer(host_b, host_a.get_id())
assert success # The push operation itself succeeds (message sent)
# Wait a bit for the handler to process
await trio.sleep(0.1)
# Verify that the error was handled gracefully (no test failure)
# The handler caught the exception and didn't propagate it
@pytest.mark.trio
async def test_identify_push_message_equivalence_real_network(security_protocol):
"""Test that both formats produce equivalent peerstore updates in real network."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Test varint format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after varint push
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
protocols_varint = peerstore_a.get_protocols(peer_id_b)
addrs_varint = peerstore_a.addrs(peer_id_b)
# Clear peerstore for next test
peerstore_a.clear_addrs(peer_id_b)
peerstore_a.clear_protocol_data(peer_id_b)
# Test raw format
host_a.set_stream_handler(
ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False)
)
await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Get peerstore state after raw push
protocols_raw = peerstore_a.get_protocols(peer_id_b)
addrs_raw = peerstore_a.addrs(peer_id_b)
# Both should produce equivalent peerstore updates
# Check that both formats successfully updated protocols
assert protocols_varint is not None
assert protocols_raw is not None
assert len(protocols_varint) > 0
assert len(protocols_raw) > 0
# Check that both formats successfully updated addresses
assert addrs_varint is not None
assert addrs_raw is not None
assert len(addrs_varint) > 0
assert len(addrs_raw) > 0
# Both should contain the same essential information
# (exact address lists might differ due to format-specific handling)
assert set(protocols_varint) == set(protocols_raw)
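For reference, the varint framing toggled by use_varint_format is the standard unsigned-varint (LEB128) length prefix used for length-delimited protobuf messages, while the raw format writes the serialized bytes as-is. A minimal encoder sketch:
def encode_uvarint(value: int) -> bytes:
    # Unsigned LEB128: 7 bits per byte, MSB set on every byte but the last
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)
        else:
            out.append(byte)
            return bytes(out)
# A length-prefixed message is then encode_uvarint(len(msg)) + msg.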
@pytest.mark.trio
async def test_identify_push_with_observed_address(security_protocol):
"""Test identify/push protocol includes observed address information."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Add some protocols to both hosts
async def dummy_handler(stream):
pass
host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler)
host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler)
# Set up identify/push handler on host_a
host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a))
# Get host_b's address as observed by host_a
from multiaddr import Multiaddr
host_b_addr = host_b.get_addrs()[0]
observed_multiaddr = Multiaddr(str(host_b_addr))
# Push identify information with observed address
await push_identify_to_peer(
host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr
)
# Wait a bit for the push to complete
await trio.sleep(0.1)
# Verify that host_a's peerstore was updated
peerstore_a = host_a.get_peerstore()
peer_id_b = host_b.get_id()
# Check that addresses were added
addrs = peerstore_a.addrs(peer_id_b)
assert len(addrs) > 0
# The observed address should be among the stored addresses
addr_strings = [str(addr) for addr in addrs]
assert str(observed_multiaddr) in addr_strings

View File

@ -3,10 +3,7 @@ from multiaddr import (
Multiaddr,
)
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.id import ID
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import (
PeerStore,
PeerStoreError,
@ -87,53 +84,3 @@ def test_peers_with_addrs():
store.clear_addrs(ID(b"peer2"))
assert set(store.peers_with_addrs()) == {ID(b"peer3")}
def test_certified_addr_book():
store = PeerStore()
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
addrs = [
Multiaddr("/ip4/127.0.0.1/tcp/9000"),
Multiaddr("/ip4/127.0.0.1/tcp/9001"),
]
ttl = 60
    # Construct a signed PeerRecord
record = PeerRecord(peer_id, addrs, 21)
envelope = seal_record(record, key_pair.private_key)
result = store.consume_peer_record(envelope, ttl)
assert result is True
# Retrieve the record
retrieved = store.get_peer_record(peer_id)
assert retrieved is not None
assert isinstance(retrieved, Envelope)
addr_list = store.addrs(peer_id)
assert set(addr_list) == set(addrs)
# Now try to push an older record (should be rejected)
old_record = PeerRecord(peer_id, [Multiaddr("/ip4/10.0.0.1/tcp/4001")], 20)
old_envelope = seal_record(old_record, key_pair.private_key)
result = store.consume_peer_record(old_envelope, ttl)
assert result is False
# Push a new record (should override)
new_addrs = [Multiaddr("/ip4/192.168.0.1/tcp/5001")]
new_record = PeerRecord(peer_id, new_addrs, 23)
new_envelope = seal_record(new_record, key_pair.private_key)
result = store.consume_peer_record(new_envelope, ttl)
assert result is True
# Confirm the record is updated
latest = store.get_peer_record(peer_id)
assert isinstance(latest, Envelope)
assert latest.record().seq == 23
    # The stored addresses should now be exactly the new record's addrs
expected_addrs = set(new_addrs)
actual_addrs = set(store.addrs(peer_id))
assert actual_addrs == expected_addrs
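The freshness rule this test exercises (seq 21 accepted, 20 rejected, 23 accepted) can be restated compactly; a hedged sketch, not the actual peerstore code:
def is_fresher(existing_seq: int | None, incoming_seq: int) -> bool:
    # Accept a record only if none exists yet or its seq is strictly newer
    return existing_seq is None or incoming_seq > existing_seq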

View File

@ -1,129 +0,0 @@
from multiaddr import Multiaddr
from libp2p.crypto.rsa import (
create_new_key_pair,
)
from libp2p.peer.envelope import (
Envelope,
consume_envelope,
make_unsigned,
seal_record,
unmarshal_envelope,
)
from libp2p.peer.id import ID
import libp2p.peer.pb.crypto_pb2 as crypto_pb
import libp2p.peer.pb.envelope_pb2 as env_pb
from libp2p.peer.peer_record import PeerRecord
DOMAIN = "libp2p-peer-record"
def test_basic_protobuf_serialization_deserialization():
pubkey = crypto_pb.PublicKey()
pubkey.Type = crypto_pb.KeyType.Ed25519
pubkey.Data = b"\x01\x02\x03"
env = env_pb.Envelope()
env.public_key.CopyFrom(pubkey)
env.payload_type = b"\x03\x01"
env.payload = b"test-payload"
env.signature = b"signature-bytes"
serialized = env.SerializeToString()
new_env = env_pb.Envelope()
new_env.ParseFromString(serialized)
assert new_env.public_key.Type == crypto_pb.KeyType.Ed25519
assert new_env.public_key.Data == b"\x01\x02\x03"
assert new_env.payload_type == b"\x03\x01"
assert new_env.payload == b"test-payload"
assert new_env.signature == b"signature-bytes"
def test_envelope_marshal_unmarshal_roundtrip():
keypair = create_new_key_pair()
pubkey = keypair.public_key
private_key = keypair.private_key
payload_type = b"\x03\x01"
payload = b"test-record"
sig = private_key.sign(make_unsigned(DOMAIN, payload_type, payload))
env = Envelope(pubkey, payload_type, payload, sig)
serialized = env.marshal_envelope()
new_env = unmarshal_envelope(serialized)
assert new_env.public_key == pubkey
assert new_env.payload_type == payload_type
assert new_env.raw_payload == payload
assert new_env.signature == sig
def test_seal_and_consume_envelope_roundtrip():
keypair = create_new_key_pair()
priv_key = keypair.private_key
pub_key = keypair.public_key
peer_id = ID.from_pubkey(pub_key)
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001"), Multiaddr("/ip4/127.0.0.1/tcp/4002")]
seq = 12345
record = PeerRecord(peer_id=peer_id, addrs=addrs, seq=seq)
# Seal
envelope = seal_record(record, priv_key)
serialized = envelope.marshal_envelope()
# Consume
env, rec = consume_envelope(serialized, record.domain())
# Assertions
assert env.public_key == pub_key
assert rec.peer_id == peer_id
assert rec.seq == seq
assert rec.addrs == addrs
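The buffer signed in seal_record and verified in consume_envelope comes from make_unsigned. Per the libp2p signed-envelope spec this is the varint-length-prefixed concatenation of domain, payload type, and payload; a sketch under the assumption that py-libp2p follows that layout (the caller encodes the domain string to bytes):
def make_unsigned_sketch(domain: bytes, payload_type: bytes, payload: bytes) -> bytes:
    def uvarint(n: int) -> bytes:
        # Unsigned LEB128 encoding of n
        out = bytearray()
        while True:
            b = n & 0x7F
            n >>= 7
            out.append(b | 0x80 if n else b)
            if not n:
                return bytes(out)
    # Length-prefixing each field keeps distinct (domain, type, payload)
    # triples from colliding under concatenation
    return b"".join(uvarint(len(f)) + f for f in (domain, payload_type, payload))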
def test_envelope_equal():
# Create a new keypair
keypair = create_new_key_pair()
private_key = keypair.private_key
# Create a mock PeerRecord
record = PeerRecord(
peer_id=ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px"),
seq=1,
addrs=[Multiaddr("/ip4/127.0.0.1/tcp/4001")],
)
# Seal it into an Envelope
env1 = seal_record(record, private_key)
# Create a second identical envelope
env2 = Envelope(
public_key=env1.public_key,
payload_type=env1.payload_type,
raw_payload=env1.raw_payload,
signature=env1.signature,
)
# They should be equal
assert env1.equal(env2)
# Now change something — payload type
env2.payload_type = b"\x99\x99"
assert not env1.equal(env2)
# Restore payload_type but change signature
env2.payload_type = env1.payload_type
env2.signature = b"wrong-signature"
assert not env1.equal(env2)
# Restore signature but change payload
env2.signature = env1.signature
env2.raw_payload = b"tampered"
assert not env1.equal(env2)
# Finally, test with a non-envelope object
assert not env1.equal("not-an-envelope")

View File

@ -1,112 +0,0 @@
import time
from multiaddr import Multiaddr
from libp2p.peer.id import ID
import libp2p.peer.pb.peer_record_pb2 as pb
from libp2p.peer.peer_record import (
PeerRecord,
addrs_from_protobuf,
peer_record_from_protobuf,
unmarshal_record,
)
# Testing methods from PeerRecord base class and PeerRecord protobuf:
def test_basic_protobuf_serialization_deserialization():
record = pb.PeerRecord()
record.seq = 1
serialized = record.SerializeToString()
new_record = pb.PeerRecord()
new_record.ParseFromString(serialized)
assert new_record.seq == 1
def test_timestamp_seq_monotonicity():
rec1 = PeerRecord()
time.sleep(1)
rec2 = PeerRecord()
assert isinstance(rec1.seq, int)
assert isinstance(rec2.seq, int)
assert rec2.seq > rec1.seq, f"Expected seq2 ({rec2.seq}) > seq1 ({rec1.seq})"
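The monotonicity assertion is consistent with seq defaulting to a wall-clock-derived value; a plausible sketch (the actual default inside PeerRecord may differ):
import time

def default_seq() -> int:
    # A nanosecond timestamp is strictly increasing across the one-second
    # sleep in the test above
    return time.time_ns()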
def test_addrs_from_protobuf_multiple_addresses():
ma1 = Multiaddr("/ip4/127.0.0.1/tcp/4001")
ma2 = Multiaddr("/ip4/127.0.0.1/tcp/4002")
addr_info1 = pb.PeerRecord.AddressInfo()
addr_info1.multiaddr = ma1.to_bytes()
addr_info2 = pb.PeerRecord.AddressInfo()
addr_info2.multiaddr = ma2.to_bytes()
result = addrs_from_protobuf([addr_info1, addr_info2])
assert result == [ma1, ma2]
def test_peer_record_from_protobuf():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
record = pb.PeerRecord()
record.peer_id = peer_id.to_bytes()
record.seq = 42
for addr_str in ["/ip4/127.0.0.1/tcp/4001", "/ip4/127.0.0.1/tcp/4002"]:
ma = Multiaddr(addr_str)
addr_info = pb.PeerRecord.AddressInfo()
addr_info.multiaddr = ma.to_bytes()
record.addresses.append(addr_info)
result = peer_record_from_protobuf(record)
assert result.peer_id == peer_id
assert result.seq == 42
assert len(result.addrs) == 2
assert str(result.addrs[0]) == "/ip4/127.0.0.1/tcp/4001"
assert str(result.addrs[1]) == "/ip4/127.0.0.1/tcp/4002"
def test_to_protobuf_generates_correct_message():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001")]
seq = 12345
record = PeerRecord(peer_id, addrs, seq)
proto = record.to_protobuf()
assert isinstance(proto, pb.PeerRecord)
assert proto.peer_id == peer_id.to_bytes()
assert proto.seq == seq
assert len(proto.addresses) == 1
assert proto.addresses[0].multiaddr == addrs[0].to_bytes()
def test_unmarshal_record_roundtrip():
record = PeerRecord(
peer_id=ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px"),
addrs=[Multiaddr("/ip4/127.0.0.1/tcp/4001")],
seq=999,
)
serialized = record.to_protobuf().SerializeToString()
deserialized = unmarshal_record(serialized)
assert deserialized.peer_id == record.peer_id
assert deserialized.seq == record.seq
assert len(deserialized.addrs) == 1
assert deserialized.addrs[0] == record.addrs[0]
def test_marshal_record_and_equal():
peer_id = ID.from_base58("QmNM23MiU1Kd7yfiKVdUnaDo8RYca8By4zDmr7uSaVV8Px")
addrs = [Multiaddr("/ip4/127.0.0.1/tcp/4001")]
original = PeerRecord(peer_id, addrs)
serialized = original.marshal_record()
    deserialized = unmarshal_record(serialized)
    assert original.equal(deserialized)

View File

@ -120,30 +120,3 @@ async def test_addr_stream_yields_new_addrs():
nursery.cancel_scope.cancel()
assert collected == [addr1, addr2]
@pytest.mark.trio
async def test_cleanup_task_remove_expired_data():
store = PeerStore()
peer_id = ID(b"peer123")
addr = Multiaddr("/ip4/127.0.0.1/tcp/4040")
    # Insert an addr with a short TTL (1s)
store.add_addr(peer_id, addr, 1)
assert store.addrs(peer_id) == [addr]
assert peer_id in store.peer_data_map
# Start cleanup task in a nursery
async with trio.open_nursery() as nursery:
# Run the cleanup task with a short interval so it runs soon
nursery.start_soon(store.start_cleanup_task, 1)
# Sleep long enough for TTL to expire and cleanup to run
await trio.sleep(3)
# Cancel the nursery to stop background tasks
nursery.cancel_scope.cancel()
# Confirm the peer data is gone from the peer_data_map
assert peer_id not in store.peer_data_map
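start_cleanup_task presumably loops on the given interval and evicts entries whose TTL has lapsed; a hedged sketch of such a loop (the _remove_expired helper is hypothetical):
import trio

async def cleanup_loop(store, interval: float) -> None:
    # Periodically sweep the peerstore, dropping expired addrs and any
    # peers left without data, as the assertion above expects
    while True:
        await trio.sleep(interval)
        store._remove_expired()  # hypothetical sweep over peer_data_map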

View File

@ -5,12 +5,10 @@ import inspect
from typing import (
NamedTuple,
)
from unittest.mock import patch
import pytest
import trio
from libp2p.custom_types import AsyncValidatorFn
from libp2p.exceptions import (
ValidationError,
)
@ -245,37 +243,7 @@ async def test_get_msg_validators():
((False, True), (True, False), (True, True)),
)
@pytest.mark.trio
async def test_validate_msg_with_throttle_condition(
is_topic_1_val_passed, is_topic_2_val_passed
):
CONCURRENCY_LIMIT = 10
state = {
"concurrency_counter": 0,
"max_observed": 0,
}
lock = trio.Lock()
async def mock_run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
async with lock:
state["concurrency_counter"] += 1
if state["concurrency_counter"] > state["max_observed"]:
state["max_observed"] = state["concurrency_counter"]
try:
result = await func(msg_forwarder, msg)
results.append(result)
finally:
async with lock:
state["concurrency_counter"] -= 1
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@ -312,20 +280,12 @@ async def test_validate_msg_with_throttle_condition(
seqno=b"\x00" * 8,
)
with patch(
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
new=mock_run_async_validator,
):
if is_topic_1_val_passed and is_topic_2_val_passed:
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else:
with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
f"Max concurrency observed: {state['max_observed']}"
)
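The throttle this mock observes comes from the pubsub validator semaphore; the bounding pattern itself is a trio.Semaphore wrapped around each validator call. A minimal sketch, assuming a limit of 10 as in CONCURRENCY_LIMIT:
import trio

validator_semaphore = trio.Semaphore(10)

async def run_validator(func, msg_forwarder, msg, results):
    # At most 10 validators run concurrently; further callers queue here
    async with validator_semaphore:
        results.append(await func(msg_forwarder, msg))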
@pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):

View File

@ -8,7 +8,6 @@ from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_MESSAGE_CHANNEL_SIZE,
@ -214,39 +213,3 @@ async def test_mplex_stream_reset(mplex_stream_pair):
# `reset` should do nothing as well.
await stream_0.reset()
await stream_1.reset()
@pytest.mark.trio
async def test_mplex_stream_close_timeout(monkeypatch, mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
    # Patch send_message to hang forever (simulate a stalled connection)
async def fake_send_message(*args, **kwargs):
await trio.sleep_forever()
monkeypatch.setattr(stream_0.muxed_conn, "send_message", fake_send_message)
with pytest.raises(TimeoutError):
await stream_0.close()
@pytest.mark.trio
async def test_mplex_stream_close_mux_unavailable(monkeypatch, mplex_stream_pair):
stream_0, _ = mplex_stream_pair
# Patch send_message to raise MuxedConnUnavailable
def raise_unavailable(*args, **kwargs):
raise MuxedConnUnavailable("Simulated conn drop")
monkeypatch.setattr(stream_0.muxed_conn, "send_message", raise_unavailable)
# Case 1: Mplex is shutting down — should not raise
stream_0.muxed_conn.event_shutting_down.set()
await stream_0.close() # Should NOT raise
# Case 2: Mplex is NOT shutting down — should raise RuntimeError
stream_0.event_local_closed = trio.Event() # Reset since it was set in first call
stream_0.muxed_conn.event_shutting_down = trio.Event() # Unset the shutdown flag
with pytest.raises(RuntimeError, match="Failed to send close message"):
await stream_0.close()

View File

@ -1,590 +0,0 @@
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
class MockMuxedConn:
"""A mock Mplex connection for testing purposes."""
def __init__(self):
self.sent_messages = []
self.streams: dict[StreamID, MplexStream] = {}
self.streams_lock = trio.Lock()
self.is_unavailable = False
async def send_message(
self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
) -> None:
"""Mocks sending a message over the connection."""
if self.is_unavailable:
raise MuxedConnUnavailable("Connection is unavailable")
self.sent_messages.append((flag, data, stream_id))
# Yield to allow other tasks to run
await trio.lowlevel.checkpoint()
def get_remote_address(self) -> tuple[str, int]:
"""Mocks getting the remote address."""
return "127.0.0.1", 4001
@pytest.fixture
async def mplex_stream():
"""Provides a fully initialized MplexStream and its communication channels."""
# Use a buffered channel to prevent deadlocks in simple tests
send_chan, recv_chan = trio.open_memory_channel(10)
stream_id = StreamID(1, is_initiator=True)
muxed_conn = MockMuxedConn()
stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan)
muxed_conn.streams[stream_id] = stream
yield stream, send_chan, muxed_conn
# Cleanup: Close channels and reset stream state
await send_chan.aclose()
await recv_chan.aclose()
# Reset stream state to prevent cross-test contamination
stream.event_local_closed = trio.Event()
stream.event_remote_closed = trio.Event()
stream.event_reset = trio.Event()
# ===============================================
# 1. Tests for Stream-Level Lock Integration
# ===============================================
@pytest.mark.trio
async def test_stream_write_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.write() acquires and releases the write lock."""
stream, _, muxed_conn = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_write
original_release = stream.rw_lock.release_write
stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_write = MagicMock(wraps=original_release)
await stream.write(b"test data")
stream.rw_lock.acquire_write.assert_awaited_once()
stream.rw_lock.release_write.assert_called_once()
# Verify the message was actually sent
assert len(muxed_conn.sent_messages) == 1
flag, data, stream_id = muxed_conn.sent_messages[0]
assert flag == HeaderTags.MessageInitiator
assert data == b"test data"
assert stream_id == stream.stream_id
@pytest.mark.trio
async def test_stream_read_is_protected_by_rwlock(mplex_stream):
"""Verify that stream.read() acquires and releases the read lock."""
stream, send_chan, _ = mplex_stream
# Mock lock methods
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire)
stream.rw_lock.release_read = AsyncMock(wraps=original_release)
await send_chan.send(b"hello")
result = await stream.read(5)
stream.rw_lock.acquire_read.assert_awaited_once()
stream.rw_lock.release_read.assert_awaited_once()
assert result == b"hello"
@pytest.mark.trio
async def test_multiple_readers_can_coexist(mplex_stream):
"""Verify multiple readers can operate concurrently."""
stream, send_chan, _ = mplex_stream
# Send enough data for both reads
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Track lock acquisition order
acquisition_order = []
release_order = []
# Patch lock methods to track concurrency
original_acquire = stream.rw_lock.acquire_read
original_release = stream.rw_lock.release_read
async def tracked_acquire():
nonlocal acquisition_order
acquisition_order.append("start")
await original_acquire()
acquisition_order.append("acquired")
async def tracked_release():
nonlocal release_order
release_order.append("start")
await original_release()
release_order.append("released")
with (
patch.object(
stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True
),
patch.object(
stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True
),
):
# Execute concurrent reads
async with trio.open_nursery() as nursery:
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Verify both reads happened
assert acquisition_order.count("start") == 2
assert acquisition_order.count("acquired") == 2
assert release_order.count("start") == 2
assert release_order.count("released") == 2
@pytest.mark.trio
async def test_writer_blocks_readers(mplex_stream):
"""Verify that a writer blocks all readers and new readers queue behind."""
stream, send_chan, _ = mplex_stream
writer_acquired = trio.Event()
readers_ready = trio.Event()
writer_finished = trio.Event()
all_readers_started = trio.Event()
all_readers_done = trio.Event()
counters = {"reader_start_count": 0, "reader_done_count": 0}
reader_target = 3
reader_start_lock = trio.Lock()
# Patch write lock to control test flow
original_acquire_write = stream.rw_lock.acquire_write
original_release_write = stream.rw_lock.release_write
async def tracked_acquire_write():
await original_acquire_write()
writer_acquired.set()
# Wait for readers to queue up
await readers_ready.wait()
# Must be synchronous since real release_write is sync
def tracked_release_write():
original_release_write()
writer_finished.set()
with (
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
patch.object(
stream.rw_lock, "release_write", side_effect=tracked_release_write
),
):
async with trio.open_nursery() as nursery:
# Start writer
nursery.start_soon(stream.write, b"test")
await writer_acquired.wait()
# Start readers
async def reader_task():
async with reader_start_lock:
counters["reader_start_count"] += 1
if counters["reader_start_count"] == reader_target:
all_readers_started.set()
try:
# This will block until data is available
await stream.read(5)
except (MplexStreamReset, MplexStreamEOF):
pass
finally:
async with reader_start_lock:
counters["reader_done_count"] += 1
if counters["reader_done_count"] == reader_target:
all_readers_done.set()
for _ in range(reader_target):
nursery.start_soon(reader_task)
# Wait until all readers are started
await all_readers_started.wait()
# Let the writer finish and release the lock
readers_ready.set()
await writer_finished.wait()
# Send data to unblock the readers
for i in range(reader_target):
await send_chan.send(b"data" + str(i).encode())
# Wait for all readers to finish
await all_readers_done.wait()
@pytest.mark.trio
async def test_writer_waits_for_readers(mplex_stream):
"""Verify a writer waits for existing readers to complete."""
stream, send_chan, _ = mplex_stream
readers_started = trio.Event()
writer_entered = trio.Event()
writer_acquiring = trio.Event()
readers_finished = trio.Event()
# Send data for readers
await send_chan.send(b"data1")
await send_chan.send(b"data2")
# Patch read lock to control test flow
original_acquire_read = stream.rw_lock.acquire_read
async def tracked_acquire_read():
await original_acquire_read()
readers_started.set()
# Wait until readers are allowed to finish
await readers_finished.wait()
# Patch write lock to detect when writer is blocked
original_acquire_write = stream.rw_lock.acquire_write
async def tracked_acquire_write():
writer_acquiring.set()
await original_acquire_write()
writer_entered.set()
with (
patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read),
patch.object(
stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write
),
):
async with trio.open_nursery() as nursery:
# Start readers
nursery.start_soon(stream.read, 5)
nursery.start_soon(stream.read, 5)
# Wait for at least one reader to acquire the lock
await readers_started.wait()
# Start writer (should block)
nursery.start_soon(stream.write, b"test")
# Wait for writer to start acquiring lock
await writer_acquiring.wait()
# Verify writer hasn't entered critical section
assert not writer_entered.is_set()
# Allow readers to finish
readers_finished.set()
# Verify writer can proceed
await writer_entered.wait()
@pytest.mark.trio
async def test_lock_behavior_during_cancellation(mplex_stream):
"""Verify that a lock is released when a task holding it is cancelled."""
stream, _, _ = mplex_stream
reader_acquired_lock = trio.Event()
async def cancellable_reader(task_status):
async with stream.rw_lock.read_lock():
reader_acquired_lock.set()
task_status.started()
# Wait indefinitely until cancelled.
await trio.sleep_forever()
async with trio.open_nursery() as nursery:
# Start the reader and wait for it to acquire the lock.
await nursery.start(cancellable_reader)
await reader_acquired_lock.wait()
# Now that the reader has the lock, cancel the nursery.
# This will cancel the reader task, and its lock should be released.
nursery.cancel_scope.cancel()
# After the nursery is cancelled, the reader should have released the lock.
# To verify, we try to acquire a write lock. If the read lock was not
# released, this will time out.
with trio.move_on_after(1) as cancel_scope:
async with stream.rw_lock.write_lock():
pass
if cancel_scope.cancelled_caught:
pytest.fail(
"Write lock could not be acquired after a cancelled reader, "
"indicating the read lock was not released."
)
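The read_lock()/write_lock() context managers used here presumably wrap the acquire/release pairs patched in the earlier tests. A sketch of that equivalence (not the actual implementation); the cancel-scope shield is what lets the release run even under the cancellation this test performs:
from contextlib import asynccontextmanager

import trio

@asynccontextmanager
async def read_lock(rw_lock):
    await rw_lock.acquire_read()
    try:
        yield
    finally:
        # Shield so release still runs when the surrounding scope is cancelled
        with trio.CancelScope(shield=True):
            await rw_lock.release_read()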
@pytest.mark.trio
async def test_concurrent_read_write_sequence(mplex_stream):
"""Verify complex sequence of interleaved reads and writes."""
stream, send_chan, _ = mplex_stream
results = []
# Use a mock to intercept writes and feed them back to the read channel
original_write = stream.write
reader1_finished = trio.Event()
writer1_finished = trio.Event()
reader2_finished = trio.Event()
async def mocked_write(data: bytes) -> None:
await original_write(data)
# Simulate the other side receiving the data and sending a response
# by putting data into the read channel.
await send_chan.send(data)
with patch.object(stream, "write", wraps=mocked_write) as patched_write:
async with trio.open_nursery() as nursery:
# Test scenario:
# 1. Reader 1 starts, waits for data.
# 2. Writer 1 writes, which gets fed back to the stream.
# 3. Reader 2 starts, reads what Writer 1 wrote.
# 4. Writer 2 writes.
async def reader1():
nonlocal results
results.append("R1 start")
data = await stream.read(5)
results.append(data)
results.append("R1 done")
reader1_finished.set()
async def writer1():
nonlocal results
await reader1_finished.wait()
results.append("W1 start")
await stream.write(b"write1")
results.append("W1 done")
writer1_finished.set()
async def reader2():
nonlocal results
await writer1_finished.wait()
# This will read the data from writer1
results.append("R2 start")
data = await stream.read(6)
results.append(data)
results.append("R2 done")
reader2_finished.set()
async def writer2():
nonlocal results
await reader2_finished.wait()
results.append("W2 start")
await stream.write(b"write2")
results.append("W2 done")
# Execute sequence
nursery.start_soon(reader1)
nursery.start_soon(writer1)
nursery.start_soon(reader2)
nursery.start_soon(writer2)
await send_chan.send(b"data1")
# Verify sequence and that write was called
assert patched_write.call_count == 2
assert results == [
"R1 start",
b"data1",
"R1 done",
"W1 start",
"W1 done",
"R2 start",
b"write1",
"R2 done",
"W2 start",
"W2 done",
]
# ===============================================
# 2. Tests for Reset, EOF, and Close Interactions
# ===============================================
@pytest.mark.trio
async def test_read_after_remote_close_triggers_eof(mplex_stream):
"""Verify reading from a remotely closed stream returns EOF correctly."""
stream, send_chan, _ = mplex_stream
# Send some data that can be read first
await send_chan.send(b"data")
# Close the channel to signify no more data will ever arrive
await send_chan.aclose()
# Mark the stream as remotely closed
stream.event_remote_closed.set()
# The first read should succeed, consuming the buffered data
data = await stream.read(4)
assert data == b"data"
# Now that the buffer is empty and the channel is closed, this should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(1)
@pytest.mark.trio
async def test_read_on_closed_stream_raises_eof(mplex_stream):
"""Test that reading from a closed stream with no data raises EOF."""
stream, send_chan, _ = mplex_stream
stream.event_remote_closed.set()
await send_chan.aclose() # Ensure the channel is closed
# Reading from a stream that is closed and has no data should raise EOF
with pytest.raises(MplexStreamEOF):
await stream.read(100)
@pytest.mark.trio
async def test_write_to_locally_closed_stream_raises(mplex_stream):
"""Verify writing to a locally closed stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
stream.event_local_closed.set()
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should fail")
@pytest.mark.trio
async def test_read_from_reset_stream_raises(mplex_stream):
"""Verify reading from a reset stream raises MplexStreamReset."""
stream, _, _ = mplex_stream
stream.event_reset.set()
with pytest.raises(MplexStreamReset):
await stream.read(10)
@pytest.mark.trio
async def test_write_to_reset_stream_raises(mplex_stream):
"""Verify writing to a reset stream raises MplexStreamClosed."""
stream, _, _ = mplex_stream
# A stream reset implies it's also locally closed.
await stream.reset()
# The `write` method checks `event_local_closed`, which `reset` sets.
with pytest.raises(MplexStreamClosed):
await stream.write(b"this should also fail")
@pytest.mark.trio
async def test_stream_reset_cleans_up_resources(mplex_stream):
"""Verify reset() cleans up stream state and resources."""
stream, _, muxed_conn = mplex_stream
stream_id = stream.stream_id
assert stream_id in muxed_conn.streams
await stream.reset()
assert stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
assert stream.event_remote_closed.is_set()
assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages
assert stream_id not in muxed_conn.streams
# Verify the underlying data channel is closed
with pytest.raises(trio.ClosedResourceError):
await stream.incoming_data_channel.receive()
# ===============================================
# 3. Rigorous Concurrency Tests with Events
# ===============================================
@pytest.mark.trio
async def test_writer_is_blocked_by_reader_using_events(mplex_stream):
"""Verify a writer must wait for a reader using trio.Event for synchronization."""
stream, _, _ = mplex_stream
reader_has_lock = trio.Event()
writer_finished = trio.Event()
async def reader():
async with stream.rw_lock.read_lock():
reader_has_lock.set()
# Hold the lock until the writer has finished its attempt
await writer_finished.wait()
async def writer():
await reader_has_lock.wait()
# This call will now block until the reader releases the lock
await stream.write(b"data")
writer_finished.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
# Verify writer is blocked
await wait_all_tasks_blocked()
assert not writer_finished.is_set()
# Signal the reader to finish
writer_finished.set()
@pytest.mark.trio
async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream):
"""Verify that multiple readers can acquire a read lock simultaneously."""
stream, _, _ = mplex_stream
counters = {"readers_in_critical_section": 0}
lock = trio.Lock() # To safely mutate the counter
reader1_acquired = trio.Event()
reader2_acquired = trio.Event()
all_readers_finished = trio.Event()
async def concurrent_reader(event_to_set: trio.Event):
async with stream.rw_lock.read_lock():
async with lock:
counters["readers_in_critical_section"] += 1
event_to_set.set()
# Wait until all readers have finished before exiting the lock context
await all_readers_finished.wait()
async with lock:
counters["readers_in_critical_section"] -= 1
async with trio.open_nursery() as nursery:
nursery.start_soon(concurrent_reader, reader1_acquired)
nursery.start_soon(concurrent_reader, reader2_acquired)
# Wait for both readers to acquire their locks
await reader1_acquired.wait()
await reader2_acquired.wait()
# Check that both were in the critical section at the same time
async with lock:
assert counters["readers_in_critical_section"] == 2
# Signal for all readers to finish
all_readers_finished.set()
# Verify they exit cleanly
await wait_all_tasks_blocked()
async with lock:
assert counters["readers_in_critical_section"] == 0

View File

@ -1 +0,0 @@
"""Discovery tests for py-libp2p."""

View File

@ -1 +0,0 @@
"""Bootstrap discovery tests for py-libp2p."""

View File

@ -1,52 +0,0 @@
#!/usr/bin/env python3
"""
Test the full bootstrap discovery integration
"""
import secrets
import pytest
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.host.basic_host import BasicHost
@pytest.mark.trio
async def test_bootstrap_integration():
"""Test bootstrap integration with new_host"""
# Test bootstrap addresses
bootstrap_addrs = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SznbYGzPwp8qDrq",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
]
# Generate key pair
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create host with bootstrap
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
# Verify bootstrap discovery is set up (cast to BasicHost for type checking)
assert isinstance(host, BasicHost), "Host should be a BasicHost instance"
assert hasattr(host, "bootstrap"), "Host should have bootstrap attribute"
assert host.bootstrap is not None, "Bootstrap discovery should be initialized"
assert len(host.bootstrap.bootstrap_addrs) == len(bootstrap_addrs), (
"Bootstrap addresses should match"
)
def test_bootstrap_no_addresses():
"""Test that bootstrap is not initialized when no addresses provided"""
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create host without bootstrap
host = new_host(key_pair=key_pair)
# Verify bootstrap is not initialized
assert isinstance(host, BasicHost)
assert not hasattr(host, "bootstrap") or host.bootstrap is None, (
"Bootstrap should not be initialized when no addresses provided"
)

View File

@ -1,39 +0,0 @@
#!/usr/bin/env python3
"""
Test bootstrap address validation
"""
from libp2p.discovery.bootstrap.utils import (
parse_bootstrap_peer_info,
validate_bootstrap_addresses,
)
def test_validate_addresses():
"""Test validation with a mix of valid and invalid addresses in one list."""
addresses = [
# Valid - using proper peer IDs
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
# Invalid
"invalid-address",
"/ip4/192.168.1.1/tcp/4001", # Missing p2p part
"", # Empty
"/ip4/127.0.0.1/tcp/4001/p2p/InvalidPeerID", # Bad peer ID
]
valid_expected = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ",
"/ip4/104.236.179.241/tcp/4001/p2p/QmSoLPppuBtQSGwKDZT2M73ULpjvfd3aZ6ha4oFGL1KrGM",
]
validated = validate_bootstrap_addresses(addresses)
assert validated == valid_expected, (
f"Expected only valid addresses, got: {validated}"
)
for addr in addresses:
peer_info = parse_bootstrap_peer_info(addr)
if addr in valid_expected:
assert peer_info is not None and peer_info.peer_id is not None, (
f"Should parse valid address: {addr}"
)
else:
assert peer_info is None, f"Should not parse invalid address: {addr}"
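Given these assertions, validate_bootstrap_addresses behaves as a filter over parse_bootstrap_peer_info; a hedged restatement of that behavior:
def validate_sketch(addresses: list[str]) -> list[str]:
    # Keep only multiaddrs that parse into a PeerInfo with a peer ID
    return [a for a in addresses if parse_bootstrap_peer_info(a) is not None]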