diff --git a/examples/identify/identify.py b/examples/identify/identify.py index 4882d2c3..38fe9574 100644 --- a/examples/identify/identify.py +++ b/examples/identify/identify.py @@ -1,6 +1,7 @@ import argparse import base64 import logging +import sys import multiaddr import trio @@ -72,14 +73,52 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") format_name = "length-prefixed" if use_varint_format else "raw protobuf" + format_flag = "--raw-format" if not use_varint_format else "" print( f"First host listening (using {format_name} format). " f"Run this from another console:\n\n" - f"identify-demo " - f"-d {client_addr}\n" + f"identify-demo {format_flag} -d {client_addr}\n" ) print("Waiting for incoming identify request...") - await trio.sleep_forever() + + # Add a custom handler to show connection events + async def custom_identify_handler(stream): + peer_id = stream.muxed_conn.peer_id + print(f"\nšŸ”— Received identify request from peer: {peer_id}") + + # Show remote address in multiaddr format + try: + from libp2p.identity.identify.identify import ( + _remote_address_to_multiaddr, + ) + + remote_address = stream.get_remote_address() + if remote_address: + observed_multiaddr = _remote_address_to_multiaddr( + remote_address + ) + # Add the peer ID to create a complete multiaddr + complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}" + print(f" Remote address: {complete_multiaddr}") + else: + print(f" Remote address: {remote_address}") + except Exception: + print(f" Remote address: {stream.get_remote_address()}") + + # Call the original handler + await identify_handler(stream) + + print(f"āœ… Successfully processed identify request from {peer_id}") + + # Replace the handler with our custom one + host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler) + + try: + await trio.sleep_forever() + except KeyboardInterrupt: + print("\nšŸ›‘ Shutting down listener...") + logger.info("Listener interrupted by user") + return else: # Create second host (dialer) @@ -93,25 +132,74 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No info = info_from_p2p_addr(maddr) print(f"Second host connecting to peer: {info.peer_id}") - await host_b.connect(info) + try: + await host_b.connect(info) + except Exception as e: + error_msg = str(e) + if "unable to connect" in error_msg or "SwarmException" in error_msg: + print(f"\nāŒ Cannot connect to peer: {info.peer_id}") + print(f" Address: {destination}") + print(f" Error: {error_msg}") + print( + "\nšŸ’” Make sure the peer is running and the address is correct." 
+ ) + return + else: + # Re-raise other exceptions + raise + stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,)) try: print("Starting identify protocol...") - # Read the complete response (could be either format) - # Read a larger chunk to get all the data before stream closes - response = await stream.read(8192) # Read enough data in one go + # Read the response using the utility function + from libp2p.utils.varint import read_length_prefixed_protobuf + + response = await read_length_prefixed_protobuf( + stream, use_varint_format + ) + full_response = response await stream.close() # Parse the response using the robust protocol-level function # This handles both old and new formats automatically - identify_msg = parse_identify_response(response) + identify_msg = parse_identify_response(full_response) print_identify_response(identify_msg) except Exception as e: - print(f"Identify protocol error: {e}") + error_msg = str(e) + print(f"Identify protocol error: {error_msg}") + + # Check for specific format mismatch errors + if "Error parsing message" in error_msg or "DecodeError" in error_msg: + print("\n" + "=" * 60) + print("FORMAT MISMATCH DETECTED!") + print("=" * 60) + if use_varint_format: + print( + "You are using length-prefixed format (default) but the " + "listener" + ) + print("is using raw protobuf format.") + print( + "\nTo fix this, run the dialer with the --raw-format flag:" + ) + print(f"identify-demo --raw-format -d {destination}") + else: + print("You are using raw protobuf format but the listener") + print("is using length-prefixed format (default).") + print( + "\nTo fix this, run the dialer without the --raw-format " + "flag:" + ) + print(f"identify-demo -d {destination}") + print("=" * 60) + else: + import traceback + + traceback.print_exc() return @@ -147,6 +235,7 @@ def main() -> None: "length-prefixed (new format)" ), ) + args = parser.parse_args() # Determine format: raw format if --raw-format is specified, otherwise @@ -154,9 +243,19 @@ def main() -> None: use_varint_format = not args.raw_format try: - trio.run(run, *(args.port, args.destination, use_varint_format)) + if args.destination: + # Run in dialer mode + trio.run(run, *(args.port, args.destination, use_varint_format)) + else: + # Run in listener mode + trio.run(run, *(args.port, args.destination, use_varint_format)) except KeyboardInterrupt: - pass + print("\nšŸ‘‹ Goodbye!") + logger.info("Application interrupted by user") + except Exception as e: + print(f"\nāŒ Error: {str(e)}") + logger.error("Error: %s", str(e)) + sys.exit(1) if __name__ == "__main__": diff --git a/examples/identify_push/identify_push_demo.py b/examples/identify_push/identify_push_demo.py index ef34fcc7..5a293f07 100644 --- a/examples/identify_push/identify_push_demo.py +++ b/examples/identify_push/identify_push_demo.py @@ -11,23 +11,26 @@ This example shows how to: import logging +import multiaddr import trio from libp2p import ( new_host, ) +from libp2p.abc import ( + INetStream, +) from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) from libp2p.custom_types import ( TProtocol, ) -from libp2p.identity.identify import ( - identify_handler_for, +from libp2p.identity.identify.pb.identify_pb2 import ( + Identify, ) from libp2p.identity.identify_push import ( ID_PUSH, - identify_push_handler_for, push_identify_to_peer, ) from libp2p.peer.peerinfo import ( @@ -38,8 +41,145 @@ from libp2p.peer.peerinfo import ( logger = logging.getLogger(__name__) +def create_custom_identify_handler(host, host_name: str): + 
"""Create a custom identify handler that displays received information.""" + + async def handle_identify(stream: INetStream) -> None: + peer_id = stream.muxed_conn.peer_id + print(f"\nšŸ” {host_name} received identify request from peer: {peer_id}") + + # Get the standard identify response using the existing function + from libp2p.identity.identify.identify import ( + _mk_identify_protobuf, + _remote_address_to_multiaddr, + ) + + # Get observed address + observed_multiaddr = None + try: + remote_address = stream.get_remote_address() + if remote_address: + observed_multiaddr = _remote_address_to_multiaddr(remote_address) + except Exception: + pass + + # Build the identify protobuf + identify_msg = _mk_identify_protobuf(host, observed_multiaddr) + response_data = identify_msg.SerializeToString() + + print(f" šŸ“‹ {host_name} identify information:") + if identify_msg.HasField("protocol_version"): + print(f" Protocol Version: {identify_msg.protocol_version}") + if identify_msg.HasField("agent_version"): + print(f" Agent Version: {identify_msg.agent_version}") + if identify_msg.HasField("public_key"): + print(f" Public Key: {identify_msg.public_key.hex()[:16]}...") + if identify_msg.listen_addrs: + print(" Listen Addresses:") + for addr_bytes in identify_msg.listen_addrs: + addr = multiaddr.Multiaddr(addr_bytes) + print(f" - {addr}") + if identify_msg.protocols: + print(" Supported Protocols:") + for protocol in identify_msg.protocols: + print(f" - {protocol}") + + # Send the response + await stream.write(response_data) + await stream.close() + + return handle_identify + + +def create_custom_identify_push_handler(host, host_name: str): + """Create a custom identify/push handler that displays received information.""" + + async def handle_identify_push(stream: INetStream) -> None: + peer_id = stream.muxed_conn.peer_id + print(f"\nšŸ“¤ {host_name} received identify/push from peer: {peer_id}") + + try: + # Read the identify message using the utility function + from libp2p.utils.varint import read_length_prefixed_protobuf + + data = await read_length_prefixed_protobuf(stream, use_varint_format=True) + + # Parse the identify message + identify_msg = Identify() + identify_msg.ParseFromString(data) + + print(" šŸ“‹ Received identify information:") + if identify_msg.HasField("protocol_version"): + print(f" Protocol Version: {identify_msg.protocol_version}") + if identify_msg.HasField("agent_version"): + print(f" Agent Version: {identify_msg.agent_version}") + if identify_msg.HasField("public_key"): + print(f" Public Key: {identify_msg.public_key.hex()[:16]}...") + if identify_msg.HasField("observed_addr") and identify_msg.observed_addr: + observed_addr = multiaddr.Multiaddr(identify_msg.observed_addr) + print(f" Observed Address: {observed_addr}") + if identify_msg.listen_addrs: + print(" Listen Addresses:") + for addr_bytes in identify_msg.listen_addrs: + addr = multiaddr.Multiaddr(addr_bytes) + print(f" - {addr}") + if identify_msg.protocols: + print(" Supported Protocols:") + for protocol in identify_msg.protocols: + print(f" - {protocol}") + + # Update the peerstore with the new information + from libp2p.identity.identify_push.identify_push import ( + _update_peerstore_from_identify, + ) + + await _update_peerstore_from_identify( + host.get_peerstore(), peer_id, identify_msg + ) + + print(f" āœ… {host_name} updated peerstore with new information") + + except Exception as e: + print(f" āŒ Error processing identify/push: {e}") + finally: + await stream.close() + + return handle_identify_push + + 
+async def display_peerstore_info(host, host_name: str, peer_id, description: str): + """Display peerstore information for a specific peer.""" + peerstore = host.get_peerstore() + + try: + addrs = peerstore.addrs(peer_id) + except Exception: + addrs = [] + + try: + protocols = peerstore.get_protocols(peer_id) + except Exception: + protocols = [] + + print(f"\nšŸ“š {host_name} peerstore for {description}:") + print(f" Peer ID: {peer_id}") + if addrs: + print(" Addresses:") + for addr in addrs: + print(f" - {addr}") + else: + print(" Addresses: None") + + if protocols: + print(" Protocols:") + for protocol in protocols: + print(f" - {protocol}") + else: + print(" Protocols: None") + + async def main() -> None: - print("\n==== Starting Identify-Push Example ====\n") + print("\n==== Starting Enhanced Identify-Push Example ====\n") # Create key pairs for the two hosts key_pair_1 = create_new_key_pair() @@ -48,45 +188,49 @@ async def main() -> None: # Create the first host host_1 = new_host(key_pair=key_pair_1) - # Set up the identify and identify/push handlers - host_1.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_1)) - host_1.set_stream_handler(ID_PUSH, identify_push_handler_for(host_1)) + # Set up custom identify and identify/push handlers + host_1.set_stream_handler( + TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_1, "Host 1") + ) + host_1.set_stream_handler( + ID_PUSH, create_custom_identify_push_handler(host_1, "Host 1") + ) # Create the second host host_2 = new_host(key_pair=key_pair_2) - # Set up the identify and identify/push handlers - host_2.set_stream_handler(TProtocol("/ipfs/id/1.0.0"), identify_handler_for(host_2)) - host_2.set_stream_handler(ID_PUSH, identify_push_handler_for(host_2)) + # Set up custom identify and identify/push handlers + host_2.set_stream_handler( + TProtocol("/ipfs/id/1.0.0"), create_custom_identify_handler(host_2, "Host 2") + ) + host_2.set_stream_handler( + ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2") + ) # Start listening on random ports using the run context manager - import multiaddr - listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") async with host_1.run([listen_addr_1]), host_2.run([listen_addr_2]): # Get the addresses of both hosts addr_1 = host_1.get_addrs()[0] - logger.info(f"Host 1 listening on {addr_1}") - print(f"Host 1 listening on {addr_1}") - print(f"Peer ID: {host_1.get_id().pretty()}") - addr_2 = host_2.get_addrs()[0] - logger.info(f"Host 2 listening on {addr_2}") - print(f"Host 2 listening on {addr_2}") - print(f"Peer ID: {host_2.get_id().pretty()}") - print("\nConnecting Host 2 to Host 1...") + print("šŸ  Host Configuration:") + print(f" Host 1: {addr_1}") + print(f" Host 1 Peer ID: {host_1.get_id().pretty()}") + print(f" Host 2: {addr_2}") + print(f" Host 2 Peer ID: {host_2.get_id().pretty()}") + + print("\nšŸ”— Connecting Host 2 to Host 1...") # Connect host_2 to host_1 peer_info = info_from_p2p_addr(addr_1) await host_2.connect(peer_info) - logger.info("Host 2 connected to Host 1") - print("Host 2 successfully connected to Host 1") + print("āœ… Host 2 successfully connected to Host 1") # Run the identify protocol from host_2 to host_1 - # (so Host 1 learns Host 2's address) + print("\nšŸ”„ Running identify protocol (Host 2 → Host 1)...") from libp2p.identity.identify.identify import ID as IDENTIFY_PROTOCOL_ID stream = await host_2.new_stream(host_1.get_id(), (IDENTIFY_PROTOCOL_ID,)) @@ -94,64 +238,58 @@ 
async def main() -> None: await stream.close() # Run the identify protocol from host_1 to host_2 - # (so Host 2 learns Host 1's address) + print("\nšŸ”„ Running identify protocol (Host 1 → Host 2)...") stream = await host_1.new_stream(host_2.get_id(), (IDENTIFY_PROTOCOL_ID,)) response = await stream.read() await stream.close() - # --- NEW CODE: Update Host 1's peerstore with Host 2's addresses --- - from libp2p.identity.identify.pb.identify_pb2 import ( - Identify, - ) - + # Update Host 1's peerstore with Host 2's addresses identify_msg = Identify() identify_msg.ParseFromString(response) peerstore_1 = host_1.get_peerstore() peer_id_2 = host_2.get_id() for addr_bytes in identify_msg.listen_addrs: maddr = multiaddr.Multiaddr(addr_bytes) - # TTL can be any positive int - peerstore_1.add_addr( - peer_id_2, - maddr, - ttl=3600, - ) - # --- END NEW CODE --- + peerstore_1.add_addr(peer_id_2, maddr, ttl=3600) - # Now Host 1's peerstore should have Host 2's address - peerstore_1 = host_1.get_peerstore() - peer_id_2 = host_2.get_id() - addrs_1_for_2 = peerstore_1.addrs(peer_id_2) - logger.info( - f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: " - f"{addrs_1_for_2}" - ) - print( - f"[DEBUG] Host 1 peerstore addresses for Host 2 before push: " - f"{addrs_1_for_2}" + # Display peerstore information before push + await display_peerstore_info( + host_1, "Host 1", peer_id_2, "Host 2 (before push)" ) # Push identify information from host_1 to host_2 - logger.info("Host 1 pushing identify information to Host 2") - print("\nHost 1 pushing identify information to Host 2...") + print("\nšŸ“¤ Host 1 pushing identify information to Host 2...") try: - # Call push_identify_to_peer which now returns a boolean success = await push_identify_to_peer(host_1, host_2.get_id()) if success: - logger.info("Identify push completed successfully") - print("Identify push completed successfully!") + print("āœ… Identify push completed successfully!") else: - logger.warning("Identify push didn't complete successfully") - print("\nWarning: Identify push didn't complete successfully") + print("āš ļø Identify push didn't complete successfully") except Exception as e: - logger.error(f"Error during identify push: {str(e)}") - print(f"\nError during identify push: {str(e)}") + print(f"āŒ Error during identify push: {str(e)}") - # Add this at the end of your async with block: - await trio.sleep(0.5) # Give background tasks time to finish + # Give a moment for the identify/push processing to complete + await trio.sleep(0.5) + + # Display peerstore information after push + await display_peerstore_info(host_1, "Host 1", peer_id_2, "Host 2 (after push)") + await display_peerstore_info( + host_2, "Host 2", host_1.get_id(), "Host 1 (after push)" + ) + + # Give more time for background tasks to finish and connections to stabilize + print("\nā³ Waiting for background tasks to complete...") + await trio.sleep(1.0) + + # Gracefully close connections to prevent connection errors + print("šŸ”Œ Closing connections...") + await host_2.disconnect(host_1.get_id()) + await trio.sleep(0.2) + + print("\nšŸŽ‰ Example completed successfully!") if __name__ == "__main__": diff --git a/examples/identify_push/identify_push_listener_dialer.py b/examples/identify_push/identify_push_listener_dialer.py index 0e573e0b..c23e62bb 100644 --- a/examples/identify_push/identify_push_listener_dialer.py +++ b/examples/identify_push/identify_push_listener_dialer.py @@ -41,6 +41,9 @@ from libp2p.identity.identify import ( ID as ID_IDENTIFY, 
identify_handler_for, ) +from libp2p.identity.identify.identify import ( + _remote_address_to_multiaddr, +) from libp2p.identity.identify.pb.identify_pb2 import ( Identify, ) @@ -72,40 +75,30 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True): async def handle_identify_push(stream: INetStream) -> None: peer_id = stream.muxed_conn.peer_id + # Get remote address information try: - if use_varint_format: - # Read length-prefixed identify message from the stream - from libp2p.utils.varint import decode_varint_from_bytes + remote_address = stream.get_remote_address() + if remote_address: + observed_multiaddr = _remote_address_to_multiaddr(remote_address) + logger.info( + "Connection from remote peer %s, address: %s, multiaddr: %s", + peer_id, + remote_address, + observed_multiaddr, + ) + print(f"\nšŸ”— Received identify/push request from peer: {peer_id}") + # Add the peer ID to create a complete multiaddr + complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}" + print(f" Remote address: {complete_multiaddr}") + except Exception as e: + logger.error("Error getting remote address: %s", e) + print(f"\nšŸ”— Received identify/push request from peer: {peer_id}") - # First read the varint length prefix - length_bytes = b"" - while True: - b = await stream.read(1) - if not b: - break - length_bytes += b - if b[0] & 0x80 == 0: - break + try: + # Use the utility function to read the protobuf message + from libp2p.utils.varint import read_length_prefixed_protobuf - if not length_bytes: - logger.warning("No length prefix received from peer %s", peer_id) - return - - msg_length = decode_varint_from_bytes(length_bytes) - - # Read the protobuf message - data = await stream.read(msg_length) - if len(data) != msg_length: - logger.warning("Incomplete message received from peer %s", peer_id) - return - else: - # Read raw protobuf message from the stream - data = b"" - while True: - chunk = await stream.read(4096) - if not chunk: - break - data += chunk + data = await read_length_prefixed_protobuf(stream, use_varint_format) identify_msg = Identify() identify_msg.ParseFromString(data) @@ -155,11 +148,41 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True): await _update_peerstore_from_identify(peerstore, peer_id, identify_msg) logger.info("Successfully processed identify/push from peer %s", peer_id) - print(f"\nSuccessfully processed identify/push from peer {peer_id}") + print(f"āœ… Successfully processed identify/push from peer {peer_id}") except Exception as e: - logger.error("Error processing identify/push from %s: %s", peer_id, e) - print(f"\nError processing identify/push from {peer_id}: {e}") + error_msg = str(e) + logger.error( + "Error processing identify/push from %s: %s", peer_id, error_msg + ) + print(f"\nError processing identify/push from {peer_id}: {error_msg}") + + # Check for specific format mismatch errors + if ( + "Error parsing message" in error_msg + or "DecodeError" in error_msg + or "ParseFromString" in error_msg + ): + print("\n" + "=" * 60) + print("FORMAT MISMATCH DETECTED!") + print("=" * 60) + if use_varint_format: + print( + "You are using length-prefixed format (default) but the " + "dialer is using raw protobuf format." + ) + print("\nTo fix this, run the dialer with the --raw-format flag:") + print( + "identify-push-listener-dialer-demo --raw-format -d
" + ) + else: + print("You are using raw protobuf format but the dialer") + print("is using length-prefixed format (default).") + print( + "\nTo fix this, run the dialer without the --raw-format flag:" + ) + print("identify-push-listener-dialer-demo -d
") + print("=" * 60) finally: # Close the stream after processing await stream.close() @@ -167,7 +190,9 @@ def custom_identify_push_handler_for(host, use_varint_format: bool = True): return handle_identify_push -async def run_listener(port: int, use_varint_format: bool = True) -> None: +async def run_listener( + port: int, use_varint_format: bool = True, raw_format_flag: bool = False +) -> None: """Run a host in listener mode.""" format_name = "length-prefixed" if use_varint_format else "raw protobuf" print( @@ -187,29 +212,41 @@ async def run_listener(port: int, use_varint_format: bool = True) -> None: ) host.set_stream_handler( ID_IDENTIFY_PUSH, - identify_push_handler_for(host, use_varint_format=use_varint_format), + custom_identify_push_handler_for(host, use_varint_format=use_varint_format), ) # Start listening listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - async with host.run([listen_addr]): - addr = host.get_addrs()[0] - logger.info("Listener host ready!") - print("Listener host ready!") + try: + async with host.run([listen_addr]): + addr = host.get_addrs()[0] + logger.info("Listener host ready!") + print("Listener host ready!") - logger.info(f"Listening on: {addr}") - print(f"Listening on: {addr}") + logger.info(f"Listening on: {addr}") + print(f"Listening on: {addr}") - logger.info(f"Peer ID: {host.get_id().pretty()}") - print(f"Peer ID: {host.get_id().pretty()}") + logger.info(f"Peer ID: {host.get_id().pretty()}") + print(f"Peer ID: {host.get_id().pretty()}") - print("\nRun dialer with command:") - print(f"identify-push-listener-dialer-demo -d {addr}") - print("\nWaiting for incoming connections... (Ctrl+C to exit)") + print("\nRun dialer with command:") + if raw_format_flag: + print(f"identify-push-listener-dialer-demo -d {addr} --raw-format") + else: + print(f"identify-push-listener-dialer-demo -d {addr}") + print("\nWaiting for incoming identify/push requests... 
(Ctrl+C to exit)") - # Keep running until interrupted - await trio.sleep_forever() + # Keep running until interrupted + try: + await trio.sleep_forever() + except KeyboardInterrupt: + print("\nšŸ›‘ Shutting down listener...") + logger.info("Listener interrupted by user") + return + except Exception as e: + logger.error(f"Listener error: {e}") + raise async def run_dialer( @@ -256,7 +293,9 @@ async def run_dialer( try: await host.connect(peer_info) logger.info("Successfully connected to listener!") - print("Successfully connected to listener!") + print("āœ… Successfully connected to listener!") + print(f" Connected to: {peer_info.peer_id}") + print(f" Full address: {destination}") # Push identify information to the listener logger.info("Pushing identify information to listener...") @@ -270,7 +309,7 @@ async def run_dialer( if success: logger.info("Identify push completed successfully!") - print("Identify push completed successfully!") + print("āœ… Identify push completed successfully!") logger.info("Example completed successfully!") print("\nExample completed successfully!") @@ -281,17 +320,57 @@ async def run_dialer( logger.warning("Example completed with warnings.") print("Example completed with warnings.") except Exception as e: - logger.error(f"Error during identify push: {str(e)}") - print(f"\nError during identify push: {str(e)}") + error_msg = str(e) + logger.error(f"Error during identify push: {error_msg}") + print(f"\nError during identify push: {error_msg}") + + # Check for specific format mismatch errors + if ( + "Error parsing message" in error_msg + or "DecodeError" in error_msg + or "ParseFromString" in error_msg + ): + print("\n" + "=" * 60) + print("FORMAT MISMATCH DETECTED!") + print("=" * 60) + if use_varint_format: + print( + "You are using length-prefixed format (default) but the " + "listener is using raw protobuf format." + ) + print( + "\nTo fix this, run the dialer with the --raw-format flag:" + ) + print( + f"identify-push-listener-dialer-demo --raw-format -d " + f"{destination}" + ) + else: + print("You are using raw protobuf format but the listener") + print("is using length-prefixed format (default).") + print( + "\nTo fix this, run the dialer without the --raw-format " + "flag:" + ) + print(f"identify-push-listener-dialer-demo -d {destination}") + print("=" * 60) logger.error("Example completed with errors.") print("Example completed with errors.") # Continue execution despite the push error except Exception as e: - logger.error(f"Error during dialer operation: {str(e)}") - print(f"\nError during dialer operation: {str(e)}") - raise + error_msg = str(e) + if "unable to connect" in error_msg or "SwarmException" in error_msg: + print(f"\nāŒ Cannot connect to peer: {peer_info.peer_id}") + print(f" Address: {destination}") + print(f" Error: {error_msg}") + print("\nšŸ’” Make sure the peer is running and the address is correct.") + return + else: + logger.error(f"Error during dialer operation: {error_msg}") + print(f"\nError during dialer operation: {error_msg}") + raise def main() -> None: @@ -301,12 +380,21 @@ def main() -> None: Without arguments, it runs as a listener on random port. With -d parameter, it runs as a dialer on random port. + Port 0 (default) means the OS will automatically assign an available port. + This prevents port conflicts when running multiple instances. + Use --raw-format to send raw protobuf messages (old format) instead of length-prefixed protobuf messages (new format, default). 
""" parser = argparse.ArgumentParser(description=description) - parser.add_argument("-p", "--port", default=0, type=int, help="source port number") + parser.add_argument( + "-p", + "--port", + default=0, + type=int, + help="source port number (0 = random available port)", + ) parser.add_argument( "-d", "--destination", @@ -321,6 +409,7 @@ def main() -> None: "length-prefixed (new format)" ), ) + args = parser.parse_args() # Determine format: raw format if --raw-format is specified, otherwise @@ -333,12 +422,12 @@ def main() -> None: trio.run(run_dialer, args.port, args.destination, use_varint_format) else: # Run in listener mode with random available port if not specified - trio.run(run_listener, args.port, use_varint_format) + trio.run(run_listener, args.port, use_varint_format, args.raw_format) except KeyboardInterrupt: - print("\nInterrupted by user") - logger.info("Interrupted by user") + print("\nšŸ‘‹ Goodbye!") + logger.info("Application interrupted by user") except Exception as e: - print(f"\nError: {str(e)}") + print(f"\nāŒ Error: {str(e)}") logger.error("Error: %s", str(e)) sys.exit(1) diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index 1e38d566..04f9efed 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify: def identify_handler_for( - host: IHost, use_varint_format: bool = False + host: IHost, use_varint_format: bool = True ) -> StreamHandlerFn: async def handle_identify(stream: INetStream) -> None: # get observed address from ``stream`` diff --git a/libp2p/identity/identify_push/identify_push.py b/libp2p/identity/identify_push/identify_push.py index f13bd970..fec62089 100644 --- a/libp2p/identity/identify_push/identify_push.py +++ b/libp2p/identity/identify_push/identify_push.py @@ -28,7 +28,7 @@ from libp2p.utils import ( varint, ) from libp2p.utils.varint import ( - decode_varint_from_bytes, + read_length_prefixed_protobuf, ) from ..identify.identify import ( @@ -66,49 +66,8 @@ def identify_push_handler_for( peer_id = stream.muxed_conn.peer_id try: - if use_varint_format: - # Read length-prefixed identify message from the stream - # First read the varint length prefix - length_bytes = b"" - while True: - b = await stream.read(1) - if not b: - break - length_bytes += b - if b[0] & 0x80 == 0: - break - - if not length_bytes: - logger.warning("No length prefix received from peer %s", peer_id) - return - - msg_length = decode_varint_from_bytes(length_bytes) - - # Read the protobuf message - data = await stream.read(msg_length) - if len(data) != msg_length: - logger.warning("Incomplete message received from peer %s", peer_id) - return - else: - # Read raw protobuf message from the stream - # For raw format, we need to read all data before the stream is closed - data = b"" - try: - # Read all available data in a single operation - data = await stream.read() - except StreamClosed: - # Try to read any remaining data - try: - data = await stream.read() - except Exception: - pass - - # If we got no data, log a warning and return - if not data: - logger.warning( - "No data received in raw format from peer %s", peer_id - ) - return + # Use the utility function to read the protobuf message + data = await read_length_prefixed_protobuf(stream, use_varint_format) identify_msg = Identify() identify_msg.ParseFromString(data) @@ -119,6 +78,11 @@ def identify_push_handler_for( ) logger.debug("Successfully processed identify/push from peer 
%s", peer_id) + + # Send acknowledgment to indicate successful processing + # This ensures the sender knows the message was received before closing + await stream.write(b"OK") + except StreamClosed: logger.debug( "Stream closed while processing identify/push from %s", peer_id @@ -127,7 +91,10 @@ def identify_push_handler_for( logger.error("Error processing identify/push from %s: %s", peer_id, e) finally: # Close the stream after processing - await stream.close() + try: + await stream.close() + except Exception: + pass # Ignore errors when closing return handle_identify_push @@ -226,7 +193,20 @@ async def push_identify_to_peer( # Send raw protobuf message await stream.write(response) - # Close the stream + # Wait for acknowledgment from the receiver with timeout + # This ensures the message was processed before closing + try: + with trio.move_on_after(1.0): # 1 second timeout + ack = await stream.read(2) # Read "OK" acknowledgment + if ack != b"OK": + logger.warning( + "Unexpected acknowledgment from peer %s: %s", peer_id, ack + ) + except Exception as e: + logger.debug("No acknowledgment received from peer %s: %s", peer_id, e) + # Continue anyway, as the message might have been processed + + # Close the stream after acknowledgment (or timeout) await stream.close() logger.debug("Successfully pushed identify to peer %s", peer_id) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index a913c721..5641ec5d 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -102,6 +102,9 @@ class TopicValidator(NamedTuple): is_async: bool +MAX_CONCURRENT_VALIDATORS = 10 + + class Pubsub(Service, IPubsub): host: IHost @@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub): peer_receive_channel: trio.MemoryReceiveChannel[ID] dead_peer_receive_channel: trio.MemoryReceiveChannel[ID] + _validator_semaphore: trio.Semaphore seen_messages: LastSeenCache @@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub): msg_id_constructor: Callable[ [rpc_pb2.Message], bytes ] = get_peer_and_seqno_msg_id, + max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS, ) -> None: """ Construct a new Pubsub object, which is responsible for handling all @@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub): # Therefore, we can only close from the receive side. self.peer_receive_channel = peer_receive self.dead_peer_receive_channel = dead_peer_receive + self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count) # Register a notifee self.host.get_network().register_notifee( PubsubNotifee(peer_send, dead_peer_send) @@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub): logger.debug("successfully published message %s", msg) - async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: + async def validate_msg( + self, + msg_forwarder: ID, + msg: rpc_pb2.Message, + ) -> None: """ Validate the received message. 
@@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub): if not validator(msg_forwarder, msg): raise ValidationError(f"Validation failed for msg={msg}") - # TODO: Implement throttle on async validators - if len(async_topic_validators) > 0: # Appends to lists are thread safe in CPython - results = [] - - async def run_async_validator(func: AsyncValidatorFn) -> None: - result = await func(msg_forwarder, msg) - results.append(result) + results: list[bool] = [] async with trio.open_nursery() as nursery: for async_validator in async_topic_validators: - nursery.start_soon(run_async_validator, async_validator) + nursery.start_soon( + self._run_async_validator, + async_validator, + msg_forwarder, + msg, + results, + ) if not all(results): raise ValidationError(f"Validation failed for msg={msg}") + async def _run_async_validator( + self, + func: AsyncValidatorFn, + msg_forwarder: ID, + msg: rpc_pb2.Message, + results: list[bool], + ) -> None: + async with self._validator_semaphore: + result = await func(msg_forwarder, msg) + results.append(result) + async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: """ Push a pubsub message to others. diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index b86357e7..e8d0561d 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from types import ( TracebackType, ) @@ -32,6 +34,72 @@ if TYPE_CHECKING: ) +class ReadWriteLock: + """ + A read-write lock that allows multiple concurrent readers + or one exclusive writer, implemented using Trio primitives. + """ + + def __init__(self) -> None: + self._readers = 0 + self._readers_lock = trio.Lock() # Protects access to _readers count + self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time + + async def acquire_read(self) -> None: + """Acquire a read lock. 
Multiple readers can hold it simultaneously.""" + try: + async with self._readers_lock: + if self._readers == 0: + await self._writer_lock.acquire() + self._readers += 1 + except trio.Cancelled: + raise + + async def release_read(self) -> None: + """Release a read lock.""" + async with self._readers_lock: + if self._readers == 1: + self._writer_lock.release() + self._readers -= 1 + + async def acquire_write(self) -> None: + """Acquire an exclusive write lock.""" + try: + await self._writer_lock.acquire() + except trio.Cancelled: + raise + + def release_write(self) -> None: + """Release the exclusive write lock.""" + self._writer_lock.release() + + @asynccontextmanager + async def read_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a read lock safely.""" + acquired = False + try: + await self.acquire_read() + acquired = True + yield + finally: + if acquired: + with trio.CancelScope() as scope: + scope.shield = True + await self.release_read() + + @asynccontextmanager + async def write_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a write lock safely.""" + acquired = False + try: + await self.acquire_write() + acquired = True + yield + finally: + if acquired: + self.release_write() + + class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go @@ -46,7 +114,7 @@ class MplexStream(IMuxedStream): read_deadline: int | None write_deadline: int | None - # TODO: Add lock for read/write to avoid interleaving receiving messages? + rw_lock: ReadWriteLock close_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. @@ -80,6 +148,7 @@ self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() + self.rw_lock = ReadWriteLock() self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @@ -113,48 +182,49 @@ :param n: number of bytes to read :return: bytes actually read """ - if n is not None and n < 0: - raise ValueError( - "the number of bytes to read `n` must be non-negative or " - f"`None` to indicate read until EOF, got n={n}" - ) - if self.event_reset.is_set(): - raise MplexStreamReset - if n is None: - return await self._read_until_eof() - if len(self._buf) == 0: - data: bytes - # Peek whether there is data available. If yes, we just read until there is - # no data, then return. - try: - data = self.incoming_data_channel.receive_nowait() - self._buf.extend(data) - except trio.EndOfChannel: - raise MplexStreamEOF - except trio.WouldBlock: - # We know `receive` will be blocked here. Wait for data here with - # `receive` and catch all kinds of errors here. + async with self.rw_lock.read_lock(): + if n is not None and n < 0: + raise ValueError( + "the number of bytes to read `n` must be non-negative or " + f"`None` to indicate read until EOF, got n={n}" + ) + if self.event_reset.is_set(): + raise MplexStreamReset + if n is None: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until + # there is no data, then return.
try: - data = await self.incoming_data_channel.receive() + data = self.incoming_data_channel.receive_nowait() self._buf.extend(data) except trio.EndOfChannel: - if self.event_reset.is_set(): - raise MplexStreamReset - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - except trio.ClosedResourceError as error: - # Probably `incoming_data_channel` is closed in `reset` when we are - # waiting for `receive`. - if self.event_reset.is_set(): - raise MplexStreamReset - raise Exception( - "`incoming_data_channel` is closed but stream is not reset. " - "This should never happen." - ) from error - self._buf.extend(self._read_return_when_blocked()) - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + raise MplexStreamEOF + except trio.WouldBlock: + # We know `receive` will be blocked here. Wait for data here with + # `receive` and catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably `incoming_data_channel` is closed in `reset` when + # we are waiting for `receive`. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "`incoming_data_channel` is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> None: """ @@ -162,14 +232,15 @@ :return: number of bytes written """ - if self.event_local_closed.is_set(): - raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") - flag = ( - HeaderTags.MessageInitiator - if self.is_initiator - else HeaderTags.MessageReceiver - ) - await self.muxed_conn.send_message(flag, data, self.stream_id) + async with self.rw_lock.write_lock(): + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + flag = ( + HeaderTags.MessageInitiator + if self.is_initiator + else HeaderTags.MessageReceiver + ) + await self.muxed_conn.send_message(flag, data, self.stream_id) async def close(self) -> None: """ diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index eba0156e..b2711e1a 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import ( MuxedStreamReset, ) +# Configure logger for this module +logger = logging.getLogger("libp2p.stream_muxer.yamux") + PROTOCOL_ID = "/yamux/1.0.0" TYPE_DATA = 0x0 TYPE_WINDOW_UPDATE = 0x1 @@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream): # Flow control: Check if we have enough send window total_len = len(data) sent = 0 - logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") + logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") while sent < total_len: # Wait for available window with timeout timeout = False async with self.window_lock: if self.send_window == 0: - logging.debug( + logger.debug( f"Stream {self.stream_id}: Window is zero, waiting for update" ) # Release lock and wait with timeout @@ -152,12 +155,12 @@ """ if increment <= 0: # If increment is zero or negative, skip sending update - logging.debug(
+ logger.debug( f"Stream {self.stream_id}: Skipping window update" f"(increment={increment})" ) return - logging.debug( + logger.debug( f"Stream {self.stream_id}: Sending window update with increment={increment}" ) @@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream): # If the stream is closed for receiving and the buffer is empty, raise EOF if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id): - logging.debug( + logger.debug( f"Stream {self.stream_id}: Stream closed for receiving and buffer empty" ) raise MuxedStreamEOF("Stream is closed for receiving") @@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream): # If buffer is not available, check if stream is closed if buffer is None: - logging.debug(f"Stream {self.stream_id}: No buffer available") + logger.debug(f"Stream {self.stream_id}: No buffer available") raise MuxedStreamEOF("Stream buffer closed") # If we have data in buffer, process it @@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream): # Send window update for the chunk we just read async with self.window_lock: self.recv_window += len(chunk) - logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}") + logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}") await self.send_window_update(len(chunk), skip_lock=True) # If stream is closed (FIN received) and buffer is empty, break if self.recv_closed and len(buffer) == 0: - logging.debug(f"Stream {self.stream_id}: Closed with empty buffer") + logger.debug(f"Stream {self.stream_id}: Closed with empty buffer") break # If stream was reset, raise reset error if self.reset_received: - logging.debug(f"Stream {self.stream_id}: Stream was reset") + logger.debug(f"Stream {self.stream_id}: Stream was reset") raise MuxedStreamReset("Stream was reset") # Wait for more data or stream closure - logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN") + logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN") await self.conn.stream_events[self.stream_id].wait() self.conn.stream_events[self.stream_id] = trio.Event() # After loop exit, first check if we have data to return if data: - logging.debug( + logger.debug( f"Stream {self.stream_id}: Returning {len(data)} bytes after loop" ) return data # No data accumulated, now check why we exited the loop if self.conn.event_shutting_down.is_set(): - logging.debug(f"Stream {self.stream_id}: Connection shutting down") + logger.debug(f"Stream {self.stream_id}: Connection shutting down") raise MuxedStreamEOF("Connection shut down") # Return empty data @@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream): data = await self.conn.read_stream(self.stream_id, n) async with self.window_lock: self.recv_window += len(data) - logging.debug( + logger.debug( f"Stream {self.stream_id}: Sending window update after read, " f"increment={len(data)}" ) @@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream): async def close(self) -> None: if not self.send_closed: - logging.debug(f"Half-closing stream {self.stream_id} (local end)") + logger.debug(f"Half-closing stream {self.stream_id} (local end)") header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 ) @@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream): async def reset(self) -> None: if not self.closed: - logging.debug(f"Resetting stream {self.stream_id}") + logger.debug(f"Resetting stream {self.stream_id}") header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 ) @@ -349,7 +352,7 @@ class Yamux(IMuxedConn): self._nursery: Nursery | None = None async def start(self) -> None: 
- logging.debug(f"Starting Yamux for {self.peer_id}") + logger.debug(f"Starting Yamux for {self.peer_id}") if self.event_started.is_set(): return async with trio.open_nursery() as nursery: @@ -362,7 +365,7 @@ class Yamux(IMuxedConn): return self.is_initiator_value async def close(self, error_code: int = GO_AWAY_NORMAL) -> None: - logging.debug(f"Closing Yamux connection with code {error_code}") + logger.debug(f"Closing Yamux connection with code {error_code}") async with self.streams_lock: if not self.event_shutting_down.is_set(): try: @@ -371,7 +374,7 @@ class Yamux(IMuxedConn): ) await self.secured_conn.write(header) except Exception as e: - logging.debug(f"Failed to send GO_AWAY: {e}") + logger.debug(f"Failed to send GO_AWAY: {e}") self.event_shutting_down.set() for stream in self.streams.values(): stream.closed = True @@ -382,12 +385,12 @@ class Yamux(IMuxedConn): self.stream_events.clear() try: await self.secured_conn.close() - logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}") + logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}") except Exception as e: - logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}") + logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}") self.event_closed.set() if self.on_close: - logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}") + logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}") if inspect.iscoroutinefunction(self.on_close): if self.on_close is not None: await self.on_close() @@ -416,7 +419,7 @@ class Yamux(IMuxedConn): header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0 ) - logging.debug(f"Sending SYN header for stream {stream_id}") + logger.debug(f"Sending SYN header for stream {stream_id}") await self.secured_conn.write(header) return stream except Exception as e: @@ -424,32 +427,32 @@ class Yamux(IMuxedConn): raise e async def accept_stream(self) -> IMuxedStream: - logging.debug("Waiting for new stream") + logger.debug("Waiting for new stream") try: stream = await self.new_stream_receive_channel.receive() - logging.debug(f"Received stream {stream.stream_id}") + logger.debug(f"Received stream {stream.stream_id}") return stream except trio.EndOfChannel: raise MuxedStreamError("No new streams available") async def read_stream(self, stream_id: int, n: int = -1) -> bytes: - logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}") + logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}") if n is None: n = -1 while True: async with self.streams_lock: if stream_id not in self.streams: - logging.debug(f"Stream {self.peer_id}:{stream_id} unknown") + logger.debug(f"Stream {self.peer_id}:{stream_id} unknown") raise MuxedStreamEOF("Stream closed") if self.event_shutting_down.is_set(): - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}: connection shutting down" ) raise MuxedStreamEOF("Connection shut down") stream = self.streams[stream_id] buffer = self.stream_buffers.get(stream_id) - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}: " f"closed={stream.closed}, " f"recv_closed={stream.recv_closed}, " @@ -457,7 +460,7 @@ class Yamux(IMuxedConn): f"buffer_len={len(buffer) if buffer else 0}" ) if buffer is None: - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}:" f"Buffer gone, assuming closed" ) @@ -470,7 +473,7 @@ class Yamux(IMuxedConn): else: data = bytes(buffer[:n]) del buffer[:n] - logging.debug( + logger.debug( 
f"Returning {len(data)} bytes" f"from stream {self.peer_id}:{stream_id}, " f"buffer_len={len(buffer)}" @@ -478,7 +481,7 @@ class Yamux(IMuxedConn): return data # If reset received and buffer is empty, raise reset if stream.reset_received: - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}:" f"reset_received=True, raising MuxedStreamReset" ) @@ -491,7 +494,7 @@ class Yamux(IMuxedConn): else: data = bytes(buffer[:n]) del buffer[:n] - logging.debug( + logger.debug( f"Returning {len(data)} bytes" f"from stream {self.peer_id}:{stream_id}, " f"buffer_len={len(buffer)}" @@ -499,21 +502,21 @@ class Yamux(IMuxedConn): return data # Check if stream is closed if stream.closed: - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}:" f"closed=True, raising MuxedStreamReset" ) raise MuxedStreamReset("Stream is reset or closed") # Check if recv_closed and buffer empty if stream.recv_closed: - logging.debug( + logger.debug( f"Stream {self.peer_id}:{stream_id}:" f"recv_closed=True, buffer empty, raising EOF" ) raise MuxedStreamEOF("Stream is closed for receiving") # Wait for data if stream is still open - logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}") + logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}") try: await self.stream_events[stream_id].wait() self.stream_events[stream_id] = trio.Event() @@ -528,7 +531,7 @@ class Yamux(IMuxedConn): try: header = await self.secured_conn.read(HEADER_SIZE) if not header or len(header) < HEADER_SIZE: - logging.debug( + logger.debug( f"Connection closed orincomplete header for peer {self.peer_id}" ) self.event_shutting_down.set() @@ -537,7 +540,7 @@ class Yamux(IMuxedConn): version, typ, flags, stream_id, length = struct.unpack( YAMUX_HEADER_FORMAT, header ) - logging.debug( + logger.debug( f"Received header for peer {self.peer_id}:" f"type={typ}, flags={flags}, stream_id={stream_id}," f"length={length}" @@ -558,7 +561,7 @@ class Yamux(IMuxedConn): 0, ) await self.secured_conn.write(ack_header) - logging.debug( + logger.debug( f"Sending stream {stream_id}" f"to channel for peer {self.peer_id}" ) @@ -576,7 +579,7 @@ class Yamux(IMuxedConn): elif typ == TYPE_DATA and flags & FLAG_RST: async with self.streams_lock: if stream_id in self.streams: - logging.debug( + logger.debug( f"Resetting stream {stream_id} for peer {self.peer_id}" ) self.streams[stream_id].closed = True @@ -585,27 +588,27 @@ class Yamux(IMuxedConn): elif typ == TYPE_DATA and flags & FLAG_ACK: async with self.streams_lock: if stream_id in self.streams: - logging.debug( + logger.debug( f"Received ACK for stream" f"{stream_id} for peer {self.peer_id}" ) elif typ == TYPE_GO_AWAY: error_code = length if error_code == GO_AWAY_NORMAL: - logging.debug( + logger.debug( f"Received GO_AWAY for peer" f"{self.peer_id}: Normal termination" ) elif error_code == GO_AWAY_PROTOCOL_ERROR: - logging.error( + logger.error( f"Received GO_AWAY for peer{self.peer_id}: Protocol error" ) elif error_code == GO_AWAY_INTERNAL_ERROR: - logging.error( + logger.error( f"Received GO_AWAY for peer {self.peer_id}: Internal error" ) else: - logging.error( + logger.error( f"Received GO_AWAY for peer {self.peer_id}" f"with unknown error code: {error_code}" ) @@ -614,7 +617,7 @@ class Yamux(IMuxedConn): break elif typ == TYPE_PING: if flags & FLAG_SYN: - logging.debug( + logger.debug( f"Received ping request with value" f"{length} for peer {self.peer_id}" ) @@ -623,7 +626,7 @@ class Yamux(IMuxedConn): ) await self.secured_conn.write(ping_header) elif flags & FLAG_ACK: 
- logging.debug( + logger.debug( f"Received ping response with value" f"{length} for peer {self.peer_id}" ) @@ -637,7 +640,7 @@ class Yamux(IMuxedConn): self.stream_buffers[stream_id].extend(data) self.stream_events[stream_id].set() if flags & FLAG_FIN: - logging.debug( + logger.debug( f"Received FIN for stream {self.peer_id}:" f"{stream_id}, marking recv_closed" ) @@ -645,7 +648,7 @@ class Yamux(IMuxedConn): if self.streams[stream_id].send_closed: self.streams[stream_id].closed = True except Exception as e: - logging.error(f"Error reading data for stream {stream_id}: {e}") + logger.error(f"Error reading data for stream {stream_id}: {e}") # Mark stream as closed on read error async with self.streams_lock: if stream_id in self.streams: @@ -659,7 +662,7 @@ class Yamux(IMuxedConn): if stream_id in self.streams: stream = self.streams[stream_id] async with stream.window_lock: - logging.debug( + logger.debug( f"Received window update for stream" f"{self.peer_id}:{stream_id}," f" increment: {increment}" @@ -674,7 +677,7 @@ class Yamux(IMuxedConn): and details.get("requested_count") == 2 and details.get("received_count") == 0 ): - logging.info( + logger.info( f"Stream closed cleanly for peer {self.peer_id}" + f" (IncompleteReadError: {details})" ) @@ -682,15 +685,32 @@ class Yamux(IMuxedConn): await self._cleanup_on_error() break else: - logging.error( + logger.error( f"Error in handle_incoming for peer {self.peer_id}: " + f"{type(e).__name__}: {str(e)}" ) else: - logging.error( - f"Error in handle_incoming for peer {self.peer_id}: " - + f"{type(e).__name__}: {str(e)}" - ) + # Handle RawConnError with more nuance + if isinstance(e, RawConnError): + error_msg = str(e) + # If RawConnError is empty, it's likely normal cleanup + if not error_msg.strip(): + logger.info( + f"RawConnError (empty) during cleanup for peer " + f"{self.peer_id} (normal connection shutdown)" + ) + else: + # Log non-empty RawConnError as warning + logger.warning( + f"RawConnError during connection handling for peer " + f"{self.peer_id}: {error_msg}" + ) + else: + # Log all other errors normally + logger.error( + f"Error in handle_incoming for peer {self.peer_id}: " + + f"{type(e).__name__}: {str(e)}" + ) # Don't crash the whole connection for temporary errors if self.event_shutting_down.is_set() or isinstance( e, (RawConnError, OSError) @@ -720,9 +740,9 @@ class Yamux(IMuxedConn): # Close the secured connection try: await self.secured_conn.close() - logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}") + logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}") except Exception as close_error: - logging.error( + logger.error( f"Error closing secured_conn for peer {self.peer_id}: {close_error}" ) @@ -731,14 +751,14 @@ class Yamux(IMuxedConn): # Call on_close callback if provided if self.on_close: - logging.debug(f"Calling on_close for peer {self.peer_id}") + logger.debug(f"Calling on_close for peer {self.peer_id}") try: if inspect.iscoroutinefunction(self.on_close): await self.on_close() else: self.on_close() except Exception as callback_error: - logging.error(f"Error in on_close callback: {callback_error}") + logger.error(f"Error in on_close callback: {callback_error}") # Cancel nursery tasks if self._nursery: diff --git a/libp2p/utils/__init__.py b/libp2p/utils/__init__.py index 2d1ee23e..0f78bfcb 100644 --- a/libp2p/utils/__init__.py +++ b/libp2p/utils/__init__.py @@ -9,6 +9,7 @@ from libp2p.utils.varint import ( read_varint_prefixed_bytes, decode_varint_from_bytes, 
decode_varint_with_size, + read_length_prefixed_protobuf, ) from libp2p.utils.version import ( get_agent_version, @@ -24,4 +25,5 @@ __all__ = [ "read_varint_prefixed_bytes", "decode_varint_from_bytes", "decode_varint_with_size", + "read_length_prefixed_protobuf", ] diff --git a/libp2p/utils/varint.py b/libp2p/utils/varint.py index 3d8d5a4f..84459efe 100644 --- a/libp2p/utils/varint.py +++ b/libp2p/utils/varint.py @@ -1,7 +1,9 @@ import itertools import logging import math +from typing import BinaryIO +from libp2p.abc import INetStream from libp2p.exceptions import ( ParseError, ) @@ -25,42 +27,41 @@ HIGH_MASK = 2**7 SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7 -def encode_uvarint(number: int) -> bytes: - """Pack `number` into varint bytes.""" - buf = b"" - while True: - towrite = number & 0x7F - number >>= 7 - if number: - buf += bytes((towrite | 0x80,)) - else: - buf += bytes((towrite,)) +def encode_uvarint(value: int) -> bytes: + """Encode an unsigned integer as a varint.""" + if value < 0: + raise ValueError("Cannot encode negative value as uvarint") + + result = bytearray() + while value >= 0x80: + result.append((value & 0x7F) | 0x80) + value >>= 7 + result.append(value & 0x7F) + return bytes(result) + + +def decode_uvarint(data: bytes) -> int: + """Decode a varint from bytes.""" + if not data: + raise ParseError("Unexpected end of data") + + result = 0 + shift = 0 + + for byte in data: + result |= (byte & 0x7F) << shift + if (byte & 0x80) == 0: break - return buf + shift += 7 + if shift >= 64: + raise ValueError("Varint too long") + + return result def decode_varint_from_bytes(data: bytes) -> int: - """ - Decode a varint from bytes and return the value. - - This is a synchronous version of decode_uvarint_from_stream for already-read bytes. - """ - res = 0 - for shift in itertools.count(0, 7): - if shift > SHIFT_64_BIT_MAX: - raise ParseError("Integer is too large...") - - if not data: - raise ParseError("Unexpected end of data") - - value = data[0] - data = data[1:] - - res += (value & LOW_MASK) << shift - - if not value & HIGH_MASK: - break - return res + """Decode a varint from bytes (backward-compatible alias for decode_uvarint).""" + return decode_uvarint(data) async def decode_uvarint_from_stream(reader: Reader) -> int: @@ -84,34 +85,33 @@ def decode_varint_with_size(data: bytes) -> tuple[int, int]: """ - Decode a varint from bytes and return (value, bytes_consumed). - Returns (0, 0) if the data doesn't start with a valid varint. + Decode a varint from bytes and return both the value and the number of bytes + consumed.
+ + Returns: + Tuple[int, int]: (value, bytes_consumed) + """ - try: - # Calculate how many bytes the varint consumes - varint_size = 0 - for i, byte in enumerate(data): - varint_size += 1 - if (byte & 0x80) == 0: - break + result = 0 + shift = 0 + bytes_consumed = 0 - if varint_size == 0: - return 0, 0 + for byte in data: + result |= (byte & 0x7F) << shift + bytes_consumed += 1 + if (byte & 0x80) == 0: + break + shift += 7 + if shift >= 64: + raise ValueError("Varint too long") - # Extract just the varint bytes - varint_bytes = data[:varint_size] - - # Decode the varint - value = decode_varint_from_bytes(varint_bytes) - - return value, varint_size - except Exception: - return 0, 0 + return result, bytes_consumed -def encode_varint_prefixed(msg_bytes: bytes) -> bytes: - varint_len = encode_uvarint(len(msg_bytes)) - return varint_len + msg_bytes +def encode_varint_prefixed(data: bytes) -> bytes: + """Encode data with a varint length prefix.""" + length_bytes = encode_uvarint(len(data)) + return length_bytes + data async def read_varint_prefixed_bytes(reader: Reader) -> bytes: @@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes: f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}' ) return msg_bytes[:-1] + + +def read_varint_prefixed_bytes_sync( + stream: BinaryIO, max_length: int = 1024 * 1024 +) -> bytes: + """ + Read varint-prefixed bytes from a stream. + + Args: + stream: A stream-like object with a read() method + max_length: Maximum allowed data length to prevent memory exhaustion + + Returns: + bytes: The data without the length prefix + + Raises: + ValueError: If the length prefix is invalid or too large + EOFError: If the stream ends unexpectedly + + """ + # Read the varint length prefix + length_bytes = b"" + while True: + byte_data = stream.read(1) + if not byte_data: + raise EOFError("Stream ended while reading varint length prefix") + + length_bytes += byte_data + if byte_data[0] & 0x80 == 0: + break + + # Decode the length + length = decode_uvarint(length_bytes) + + if length > max_length: + raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}") + + # Read the data + data = stream.read(length) + if len(data) != length: + raise EOFError(f"Expected {length} bytes, got {len(data)}") + + return data + + +async def read_length_prefixed_protobuf( + stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024 +) -> bytes: + """Read a protobuf message from a stream, handling both formats.""" + if use_varint_format: + # Read length-prefixed protobuf message from the stream + # First read the varint length prefix + length_bytes = b"" + while True: + b = await stream.read(1) + if not b: + raise Exception("No length prefix received") + + length_bytes += b + if b[0] & 0x80 == 0: + break + + msg_length = decode_varint_from_bytes(length_bytes) + + if msg_length > max_length: + raise Exception( + f"Message length {msg_length} exceeds maximum allowed {max_length}" + ) + + # Read the protobuf message + data = await stream.read(msg_length) + if len(data) != msg_length: + raise Exception( + f"Incomplete message: expected {msg_length}, got {len(data)}" + ) + + return data + else: + # Read raw protobuf message from the stream + # For raw format, read all available data in one go + data = await stream.read() + + # If we got no data, raise an exception + if not data: + raise Exception("No data received in raw format") + + if len(data) > max_length: + raise Exception( + f"Message length {len(data)} exceeds maximum allowed 
{max_length}" + ) + + return data diff --git a/newsfragments/748.feature.rst b/newsfragments/748.feature.rst new file mode 100644 index 00000000..199e5b3b --- /dev/null +++ b/newsfragments/748.feature.rst @@ -0,0 +1 @@ + Add lock for read/write to avoid interleaving receiving messages in mplex_stream.py diff --git a/newsfragments/755.performance.rst b/newsfragments/755.performance.rst new file mode 100644 index 00000000..386e661b --- /dev/null +++ b/newsfragments/755.performance.rst @@ -0,0 +1,2 @@ +Added throttling for async topic validators in validate_msg, enforcing a +concurrency limit to prevent resource exhaustion under heavy load. diff --git a/newsfragments/766.internal.rst b/newsfragments/766.internal.rst new file mode 100644 index 00000000..1ecce428 --- /dev/null +++ b/newsfragments/766.internal.rst @@ -0,0 +1 @@ +Pin py-multiaddr dependency to specific git commit db8124e2321f316d3b7d2733c7df11d6ad9c03e6 diff --git a/newsfragments/775.docs.rst b/newsfragments/775.docs.rst new file mode 100644 index 00000000..300b27ca --- /dev/null +++ b/newsfragments/775.docs.rst @@ -0,0 +1 @@ +Clarified the requirement for a trailing newline in newsfragments to pass lint checks. diff --git a/newsfragments/778.bugfix.rst b/newsfragments/778.bugfix.rst new file mode 100644 index 00000000..a18832a4 --- /dev/null +++ b/newsfragments/778.bugfix.rst @@ -0,0 +1 @@ +Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats. diff --git a/newsfragments/784.bugfix.rst b/newsfragments/784.bugfix.rst new file mode 100644 index 00000000..be96cf2e --- /dev/null +++ b/newsfragments/784.bugfix.rst @@ -0,0 +1 @@ +Fixed incorrect handling of raw protobuf format in identify push protocol. The identify push example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats. diff --git a/newsfragments/784.internal.rst b/newsfragments/784.internal.rst new file mode 100644 index 00000000..9089938d --- /dev/null +++ b/newsfragments/784.internal.rst @@ -0,0 +1 @@ +Yamux RawConnError Logging Refactor - Improved error handling and debug logging diff --git a/newsfragments/README.md b/newsfragments/README.md index 177d6492..4b54df7c 100644 --- a/newsfragments/README.md +++ b/newsfragments/README.md @@ -18,12 +18,19 @@ Each file should be named like `..rst`, where - `performance` - `removal` -So for example: `123.feature.rst`, `456.bugfix.rst` +So for example: `1024.feature.rst` + +**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks. + +``` +Added support for Ed25519 key generation in libp2p peer identity creation. + +``` If the PR fixes an issue, use that number here. If there is no issue, then open up the PR first and use the PR number for the newsfragment. -Note that the `towncrier` tool will automatically +**Note** that the `towncrier` tool will automatically reflow your text, so don't try to do any fancy formatting. Run `towncrier build --draft` to get a preview of what the release notes entry will look like in the final release notes. 
diff --git a/pyproject.toml b/pyproject.toml index 259c6c17..1b9589af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,8 @@ dependencies = [ "exceptiongroup>=1.2.0; python_version < '3.11'", "grpcio>=1.41.0", "lru-dict>=1.1.6", - "multiaddr>=0.0.9", + # "multiaddr>=0.0.9", + "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6", "mypy-protobuf>=3.0.0", "noiseprotocol>=0.3.0", "protobuf>=4.21.0,<5.0.0", diff --git a/tests/core/identity/identify/test_identify_integration.py b/tests/core/identity/identify/test_identify_integration.py new file mode 100644 index 00000000..e4ebcba7 --- /dev/null +++ b/tests/core/identity/identify/test_identify_integration.py @@ -0,0 +1,241 @@ +import logging + +import pytest + +from libp2p.custom_types import TProtocol +from libp2p.identity.identify.identify import ( + AGENT_VERSION, + ID, + PROTOCOL_VERSION, + _multiaddr_to_bytes, + identify_handler_for, + parse_identify_response, +) +from tests.utils.factories import host_pair_factory + +logger = logging.getLogger("libp2p.identity.identify-integration-test") + + +@pytest.mark.trio +async def test_identify_protocol_varint_format_integration(security_protocol): + """Test identify protocol with varint format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_protocol_raw_format_integration(security_protocol): + """Test identify protocol with raw format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_default_format_behavior(security_protocol): + """Test identify protocol uses correct default format.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Use default identify handler (should use varint format) + host_a.set_stream_handler(ID, identify_handler_for(host_a)) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert 
result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol): + """Test varint dialer with raw listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol): + """Test raw dialer with varint listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_format_detection_robustness(security_protocol): + """Test identify protocol format detection is robust with various message sizes.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test both formats with different message sizes + for use_varint in [True, False]: + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=use_varint) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_large_message_handling(security_protocol): + """Test identify protocol handles large messages with many protocols.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add many protocols to make the message larger + async def dummy_handler(stream): + pass + + for i in range(10): + host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler) + + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert 
result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_message_equivalence_real_network(security_protocol): + """Test that both formats produce equivalent messages in real network.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + stream_varint = await host_b.new_stream(host_a.get_id(), (ID,)) + response_varint = await stream_varint.read(8192) + await stream_varint.close() + + # Test raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + stream_raw = await host_b.new_stream(host_a.get_id(), (ID,)) + response_raw = await stream_raw.read(8192) + await stream_raw.close() + + # Parse both responses + result_varint = parse_identify_response(response_varint) + result_raw = parse_identify_response(response_raw) + + # Both should produce identical parsed results + assert result_varint.agent_version == result_raw.agent_version + assert result_varint.protocol_version == result_raw.protocol_version + assert result_varint.public_key == result_raw.public_key + assert result_varint.listen_addrs == result_raw.listen_addrs diff --git a/tests/core/identity/identify/test_identify_parsing.py b/tests/core/identity/identify/test_identify_parsing.py deleted file mode 100644 index d76d82a1..00000000 --- a/tests/core/identity/identify/test_identify_parsing.py +++ /dev/null @@ -1,410 +0,0 @@ -import pytest - -from libp2p.identity.identify.identify import ( - _mk_identify_protobuf, -) -from libp2p.identity.identify.pb.identify_pb2 import ( - Identify, -) -from libp2p.io.abc import Closer, Reader, Writer -from libp2p.utils.varint import ( - decode_varint_from_bytes, - encode_varint_prefixed, -) -from tests.utils.factories import ( - host_pair_factory, -) - - -class MockStream(Reader, Writer, Closer): - """Mock stream for testing identify protocol compatibility.""" - - def __init__(self, data: bytes): - self.data = data - self.position = 0 - self.closed = False - - async def read(self, n: int | None = None) -> bytes: - if self.closed or self.position >= len(self.data): - return b"" - if n is None: - n = len(self.data) - self.position - result = self.data[self.position : self.position + n] - self.position += len(result) - return result - - async def write(self, data: bytes) -> None: - # Mock write - just store the data - pass - - async def close(self) -> None: - self.closed = True - - -def create_identify_message(host, observed_multiaddr=None): - """Create an identify protobuf message.""" - return _mk_identify_protobuf(host, observed_multiaddr) - - -def create_new_format_message(identify_msg): - """Create a new format (length-prefixed) identify message.""" - msg_bytes = identify_msg.SerializeToString() - return encode_varint_prefixed(msg_bytes) - - -def create_old_format_message(identify_msg): - """Create an old format (raw protobuf) identify message.""" - return identify_msg.SerializeToString() - - -async def read_new_format_message(stream) -> bytes: - """Read a new format (length-prefixed) identify message.""" - # Read varint length prefix - length_bytes = b"" - while True: - b = await stream.read(1) - if not b: - break - length_bytes += b - if b[0] & 0x80 == 0: - break - - if not length_bytes: - raise ValueError("No length prefix received") - - 
msg_length = decode_varint_from_bytes(length_bytes) - - # Read the protobuf message - response = await stream.read(msg_length) - if len(response) != msg_length: - raise ValueError("Incomplete message received") - - return response - - -async def read_old_format_message(stream) -> bytes: - """Read an old format (raw protobuf) identify message.""" - # Read all available data - response = b"" - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message(stream) -> bytes: - """Read an identify message in either old or new format.""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining protobuf message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - message_data = first_bytes[varint_len:] + remaining_bytes - - # Try to parse as protobuf to validate - try: - Identify().ParseFromString(message_data) - return message_data - except Exception: - # If protobuf parsing fails, fall back to old format - pass - except Exception: - pass - - # Fall back to old format (raw protobuf) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message_simple(stream) -> bytes: - """Read a message in either old or new format (simplified version for testing).""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - return first_bytes[varint_len:] + remaining_bytes - except Exception: - pass - - # Fall back to old format (raw data) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -def detect_format(data): - """Detect if data is in new or old format (varint-prefixed or raw protobuf).""" - if not data: - return "unknown" - - # Try to decode as varint - try: - msg_length = decode_varint_from_bytes(data) - - # Validate that the length is reasonable - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate varint length - varint_len = 0 - for i, byte in enumerate(data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Check if we have enough 
data for the message - if len(data) >= varint_len + msg_length: - # Additional check: try to parse the message as protobuf - try: - message_data = data[varint_len : varint_len + msg_length] - Identify().ParseFromString(message_data) - return "new" - except Exception: - # If protobuf parsing fails, it's probably not a valid new format - pass - except Exception: - pass - - # If varint decoding fails or length is unreasonable, assume old format - return "old" - - -@pytest.mark.trio -async def test_identify_new_format_compatibility(security_protocol): - """Test that identify protocol works with new format (length-prefixed) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_old_format_compatibility(security_protocol): - """Test that identify protocol works with old format (raw protobuf) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_old_format(security_protocol): - """Test backward compatibility reader with old format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader (which should work reliably) - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_new_format(security_protocol): - """Test backward compatibility reader with new format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create 
identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader (which should work reliably) - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_format_detection(security_protocol): - """Test that the format detection works correctly.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Test new format detection - new_format_data = create_new_format_message(identify_msg) - format_type = detect_format(new_format_data) - assert format_type == "new", "New format should be detected correctly" - - # Test old format detection - old_format_data = create_old_format_message(identify_msg) - format_type = detect_format(old_format_data) - assert format_type == "old", "Old format should be detected correctly" - - -@pytest.mark.trio -async def test_identify_error_handling(security_protocol): - """Test error handling for malformed messages.""" - from libp2p.exceptions import ParseError - - # Test with empty data - stream = MockStream(b"") - with pytest.raises(ValueError, match="No data received"): - await read_compatible_message(stream) - - # Test with incomplete varint - stream = MockStream(b"\x80") # Incomplete varint - with pytest.raises(ParseError, match="Unexpected end of data"): - await read_new_format_message(stream) - - # Test with invalid protobuf data - stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf - with pytest.raises(Exception): # Should fail when parsing protobuf - response = await read_new_format_message(stream) - Identify().ParseFromString(response) - - -@pytest.mark.trio -async def test_identify_message_equivalence(security_protocol): - """Test that old and new format messages are equivalent.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create both formats - new_format_data = create_new_format_message(identify_msg) - old_format_data = create_old_format_message(identify_msg) - - # Extract the protobuf message from new format - varint_len = 0 - for i, byte in enumerate(new_format_data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - new_format_protobuf = new_format_data[varint_len:] - - # The protobuf messages should be identical - assert new_format_protobuf == old_format_data, ( - "Protobuf messages should be identical in both formats" - ) diff --git a/tests/core/identity/identify_push/test_identify_push_integration.py b/tests/core/identity/identify_push/test_identify_push_integration.py new file mode 100644 index 00000000..9ee38b10 --- /dev/null +++ b/tests/core/identity/identify_push/test_identify_push_integration.py @@ -0,0 +1,552 @@ +import logging + +import pytest +import trio + +from libp2p.custom_types import TProtocol +from libp2p.identity.identify_push.identify_push import ( + ID_PUSH, + identify_push_handler_for, + 
push_identify_to_peer, + push_identify_to_peers, +) +from tests.utils.factories import host_pair_factory + +logger = logging.getLogger("libp2p.identity.identify-push-integration-test") + + +@pytest.mark.trio +async def test_identify_push_protocol_varint_format_integration(security_protocol): + """Test identify/push protocol with varint format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to host_b so it has something to push + async def dummy_handler(stream): + pass + + host_b.set_stream_handler(TProtocol("/test/protocol/1"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/2"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler( + ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True) + ) + + # Push identify information from host_b to host_a + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + + # Check that addresses were added + addrs = peerstore_a.addrs(peer_id_b) + assert len(addrs) > 0 + + # Check that protocols were added + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + # The protocols should include the dummy protocols we added + assert len(protocols) >= 2 # Should include the dummy protocols + + +@pytest.mark.trio +async def test_identify_push_protocol_raw_format_integration(security_protocol): + """Test identify/push protocol with raw format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler( + ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False) + ) + + # Push identify information from host_b to host_a + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + + # Check that addresses were added + addrs = peerstore_a.addrs(peer_id_b) + assert len(addrs) > 0 + + # Check that protocols were added + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) >= 1 # Should include the dummy protocol + + +@pytest.mark.trio +async def test_identify_push_default_format_behavior(security_protocol): + """Test identify/push protocol uses correct default format.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Use default identify/push handler (should use varint format) + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + + # Push identify information from host_b to host_a + await push_identify_to_peer(host_b, host_a.get_id()) + + # Wait a bit 
for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + + # Check that protocols were added + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) >= 1 # Should include the dummy protocol + + +@pytest.mark.trio +async def test_identify_push_cross_format_compatibility_varint_to_raw( + security_protocol, +): + """Test varint pusher with raw listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Use an event to signal when handler is ready + handler_ready = trio.Event() + + # Create a wrapper handler that signals when ready + original_handler = identify_push_handler_for(host_a, use_varint_format=False) + + async def wrapped_handler(stream): + handler_ready.set() # Signal that handler is ready + await original_handler(stream) + + # Host A uses raw format with wrapped handler + host_a.set_stream_handler(ID_PUSH, wrapped_handler) + + # Host B pushes with varint format (should fail gracefully) + success = await push_identify_to_peer( + host_b, host_a.get_id(), use_varint_format=True + ) + # This should fail due to format mismatch + # Note: The format detection might be more robust than expected + # so we just check that the operation completes + assert isinstance(success, bool) + + +@pytest.mark.trio +async def test_identify_push_cross_format_compatibility_raw_to_varint( + security_protocol, +): + """Test raw pusher with varint listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Use an event to signal when handler is ready + handler_ready = trio.Event() + + # Create a wrapper handler that signals when ready + original_handler = identify_push_handler_for(host_a, use_varint_format=True) + + async def wrapped_handler(stream): + handler_ready.set() # Signal that handler is ready + await original_handler(stream) + + # Host A uses varint format with wrapped handler + host_a.set_stream_handler(ID_PUSH, wrapped_handler) + + # Host B pushes with raw format (should fail gracefully) + success = await push_identify_to_peer( + host_b, host_a.get_id(), use_varint_format=False + ) + # This should fail due to format mismatch + # Note: The format detection might be more robust than expected + # so we just check that the operation completes + assert isinstance(success, bool) + + +@pytest.mark.trio +async def test_identify_push_multiple_peers_integration(security_protocol): + """Test identify/push protocol with multiple peers.""" + # Create two hosts using the factory + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Create a third host following the pattern from test_identify_push.py + import multiaddr + + from libp2p import new_host + from libp2p.crypto.secp256k1 import create_new_key_pair + from libp2p.peer.peerinfo import info_from_p2p_addr + + # Create a new key pair for host_c + key_pair_c = create_new_key_pair() + host_c = new_host(key_pair=key_pair_c) + + # Set up identify/push handlers on all hosts + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b)) + host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c)) + + # Start listening on a random port using the run context manager + listen_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") + async 
with host_c.run([listen_addr]): + # Connect host_c to host_a and host_b using the correct pattern + await host_c.connect(info_from_p2p_addr(host_a.get_addrs()[0])) + await host_c.connect(info_from_p2p_addr(host_b.get_addrs()[0])) + + # Push identify information from host_a to all connected peers + await push_identify_to_peers(host_a) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Check that host_b's peerstore has been updated with host_a's information + peerstore_b = host_b.get_peerstore() + peer_id_a = host_a.get_id() + + # Check that the peer is in the peerstore + assert peer_id_a in peerstore_b.peer_ids() + + # Check that host_c's peerstore has been updated with host_a's information + peerstore_c = host_c.get_peerstore() + + # Check that the peer is in the peerstore + assert peer_id_a in peerstore_c.peer_ids() + + # Test for push_identify to only connected peers and not all peers + # Disconnect a from c. + await host_c.disconnect(host_a.get_id()) + + await push_identify_to_peers(host_c) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Check that host_a's peerstore has not been updated with host_c's info + assert host_c.get_id() not in host_a.get_peerstore().peer_ids() + # Check that host_b's peerstore has been updated with host_c's info + assert host_c.get_id() in host_b.get_peerstore().peer_ids() + + +@pytest.mark.trio +async def test_identify_push_large_message_handling(security_protocol): + """Test identify/push protocol handles large messages with many protocols.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add many protocols to make the message larger + async def dummy_handler(stream): + pass + + for i in range(10): + host_b.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler) + + # Also add some protocols to host_a to ensure it has protocols to push + for i in range(5): + host_a.set_stream_handler(TProtocol(f"/test/protocol/a{i}"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler( + ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True) + ) + + # Push identify information from host_b to host_a + success = await push_identify_to_peer( + host_b, host_a.get_id(), use_varint_format=True + ) + assert success + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated with all protocols + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) >= 10 # Should include the dummy protocols + + +@pytest.mark.trio +async def test_identify_push_peerstore_update_completeness(security_protocol): + """Test that identify/push updates all relevant peerstore information.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + + # Push identify information from host_b to host_a + await push_identify_to_peer(host_b, host_a.get_id()) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = 
host_a.get_peerstore() + peer_id_b = host_b.get_id() + + # Check that protocols were added + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) > 0 + + # Check that addresses were added + addrs = peerstore_a.addrs(peer_id_b) + assert len(addrs) > 0 + + # Check that public key was added + pubkey = peerstore_a.pubkey(peer_id_b) + assert pubkey is not None + assert pubkey.serialize() == host_b.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_push_concurrent_requests(security_protocol): + """Test identify/push protocol handles concurrent requests properly.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + + # Make multiple concurrent push requests + results = [] + + async def push_identify(): + result = await push_identify_to_peer(host_b, host_a.get_id()) + results.append(result) + + # Run multiple concurrent pushes using nursery + async with trio.open_nursery() as nursery: + for _ in range(3): + nursery.start_soon(push_identify) + + # All should succeed + assert len(results) == 3 + assert all(results) + + # Wait a bit for the pushes to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) > 0 + + +@pytest.mark.trio +async def test_identify_push_stream_handling(security_protocol): + """Test identify/push protocol properly handles stream lifecycle.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + + # Push identify information from host_b to host_a + success = await push_identify_to_peer(host_b, host_a.get_id()) + assert success + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + protocols = peerstore_a.get_protocols(peer_id_b) + assert protocols is not None + assert len(protocols) > 0 + + +@pytest.mark.trio +async def test_identify_push_error_handling(security_protocol): + """Test identify/push protocol handles errors gracefully.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Create a handler that raises an exception but catches it to prevent test + # failure + async def error_handler(stream): + try: + await stream.close() + raise Exception("Test error") + except Exception: + # Catch the exception to prevent it from propagating up + pass + + host_a.set_stream_handler(ID_PUSH, error_handler) + + # Push should complete (message sent) but handler should fail gracefully + success = await push_identify_to_peer(host_b, host_a.get_id()) + assert 
success # The push operation itself succeeds (message sent) + + # Wait a bit for the handler to process + await trio.sleep(0.1) + + # Verify that the error was handled gracefully (no test failure) + # The handler caught the exception and didn't propagate it + + +@pytest.mark.trio +async def test_identify_push_message_equivalence_real_network(security_protocol): + """Test that both formats produce equivalent peerstore updates in real network.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Test varint format + host_a.set_stream_handler( + ID_PUSH, identify_push_handler_for(host_a, use_varint_format=True) + ) + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Get peerstore state after varint push + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + protocols_varint = peerstore_a.get_protocols(peer_id_b) + addrs_varint = peerstore_a.addrs(peer_id_b) + + # Clear peerstore for next test + peerstore_a.clear_addrs(peer_id_b) + peerstore_a.clear_protocol_data(peer_id_b) + + # Test raw format + host_a.set_stream_handler( + ID_PUSH, identify_push_handler_for(host_a, use_varint_format=False) + ) + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=False) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Get peerstore state after raw push + protocols_raw = peerstore_a.get_protocols(peer_id_b) + addrs_raw = peerstore_a.addrs(peer_id_b) + + # Both should produce equivalent peerstore updates + # Check that both formats successfully updated protocols + assert protocols_varint is not None + assert protocols_raw is not None + assert len(protocols_varint) > 0 + assert len(protocols_raw) > 0 + + # Check that both formats successfully updated addresses + assert addrs_varint is not None + assert addrs_raw is not None + assert len(addrs_varint) > 0 + assert len(addrs_raw) > 0 + + # Both should contain the same essential information + # (exact address lists might differ due to format-specific handling) + assert set(protocols_varint) == set(protocols_raw) + + +@pytest.mark.trio +async def test_identify_push_with_observed_address(security_protocol): + """Test identify/push protocol includes observed address information.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add some protocols to both hosts + async def dummy_handler(stream): + pass + + host_a.set_stream_handler(TProtocol("/test/protocol/a"), dummy_handler) + host_b.set_stream_handler(TProtocol("/test/protocol/b"), dummy_handler) + + # Set up identify/push handler on host_a + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) + + # Get host_b's address as observed by host_a + from multiaddr import Multiaddr + + host_b_addr = host_b.get_addrs()[0] + observed_multiaddr = Multiaddr(str(host_b_addr)) + + # Push identify information with observed address + await push_identify_to_peer( + host_b, host_a.get_id(), observed_multiaddr=observed_multiaddr + ) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Verify that host_a's peerstore was updated + peerstore_a = host_a.get_peerstore() + peer_id_b = host_b.get_id() + + # Check 
that addresses were added + addrs = peerstore_a.addrs(peer_id_b) + assert len(addrs) > 0 + + # The observed address should be among the stored addresses + addr_strings = [str(addr) for addr in addrs] + assert str(observed_multiaddr) in addr_strings diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index 81389ed1..e674dbc0 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -5,10 +5,12 @@ import inspect from typing import ( NamedTuple, ) +from unittest.mock import patch import pytest import trio +from libp2p.custom_types import AsyncValidatorFn from libp2p.exceptions import ( ValidationError, ) @@ -243,7 +245,37 @@ async def test_get_msg_validators(): ((False, True), (True, False), (True, True)), ) @pytest.mark.trio -async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): +async def test_validate_msg_with_throttle_condition( + is_topic_1_val_passed, is_topic_2_val_passed +): + CONCURRENCY_LIMIT = 10 + + state = { + "concurrency_counter": 0, + "max_observed": 0, + } + lock = trio.Lock() + + async def mock_run_async_validator( + self, + func: AsyncValidatorFn, + msg_forwarder: ID, + msg: rpc_pb2.Message, + results: list[bool], + ) -> None: + async with self._validator_semaphore: + async with lock: + state["concurrency_counter"] += 1 + if state["concurrency_counter"] > state["max_observed"]: + state["max_observed"] = state["concurrency_counter"] + + try: + result = await func(msg_forwarder, msg) + results.append(result) + finally: + async with lock: + state["concurrency_counter"] -= 1 + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool: @@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): seqno=b"\x00" * 8, ) - if is_topic_1_val_passed and is_topic_2_val_passed: - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - else: - with pytest.raises(ValidationError): + with patch( + "libp2p.pubsub.pubsub.Pubsub._run_async_validator", + new=mock_run_async_validator, + ): + if is_topic_1_val_passed and is_topic_2_val_passed: await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + else: + with pytest.raises(ValidationError): + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + + assert state["max_observed"] <= CONCURRENCY_LIMIT, ( + f"Max concurrency observed: {state['max_observed']}" + ) @pytest.mark.trio diff --git a/tests/core/stream_muxer/test_read_write_lock.py b/tests/core/stream_muxer/test_read_write_lock.py new file mode 100644 index 00000000..621f3841 --- /dev/null +++ b/tests/core/stream_muxer/test_read_write_lock.py @@ -0,0 +1,590 @@ +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import trio +from trio.testing import wait_all_tasks_blocked + +from libp2p.stream_muxer.exceptions import ( + MuxedConnUnavailable, +) +from libp2p.stream_muxer.mplex.constants import HeaderTags +from libp2p.stream_muxer.mplex.datastructures import StreamID +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream + + +class MockMuxedConn: + """A mock Mplex connection for testing purposes.""" + + def __init__(self): + self.sent_messages = [] + self.streams: dict[StreamID, MplexStream] = {} + self.streams_lock = trio.Lock() + self.is_unavailable = False + + async def send_message( + self, 
flag: HeaderTags, data: bytes | None, stream_id: StreamID + ) -> None: + """Mocks sending a message over the connection.""" + if self.is_unavailable: + raise MuxedConnUnavailable("Connection is unavailable") + self.sent_messages.append((flag, data, stream_id)) + # Yield to allow other tasks to run + await trio.lowlevel.checkpoint() + + def get_remote_address(self) -> tuple[str, int]: + """Mocks getting the remote address.""" + return "127.0.0.1", 4001 + + +@pytest.fixture +async def mplex_stream(): + """Provides a fully initialized MplexStream and its communication channels.""" + # Use a buffered channel to prevent deadlocks in simple tests + send_chan, recv_chan = trio.open_memory_channel(10) + stream_id = StreamID(1, is_initiator=True) + muxed_conn = MockMuxedConn() + stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan) + muxed_conn.streams[stream_id] = stream + + yield stream, send_chan, muxed_conn + + # Cleanup: Close channels and reset stream state + await send_chan.aclose() + await recv_chan.aclose() + # Reset stream state to prevent cross-test contamination + stream.event_local_closed = trio.Event() + stream.event_remote_closed = trio.Event() + stream.event_reset = trio.Event() + + +# =============================================== +# 1. Tests for Stream-Level Lock Integration +# =============================================== + + +@pytest.mark.trio +async def test_stream_write_is_protected_by_rwlock(mplex_stream): + """Verify that stream.write() acquires and releases the write lock.""" + stream, _, muxed_conn = mplex_stream + + # Mock lock methods + original_acquire = stream.rw_lock.acquire_write + original_release = stream.rw_lock.release_write + + stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_write = MagicMock(wraps=original_release) + + await stream.write(b"test data") + + stream.rw_lock.acquire_write.assert_awaited_once() + stream.rw_lock.release_write.assert_called_once() + + # Verify the message was actually sent + assert len(muxed_conn.sent_messages) == 1 + flag, data, stream_id = muxed_conn.sent_messages[0] + assert flag == HeaderTags.MessageInitiator + assert data == b"test data" + assert stream_id == stream.stream_id + + +@pytest.mark.trio +async def test_stream_read_is_protected_by_rwlock(mplex_stream): + """Verify that stream.read() acquires and releases the read lock.""" + stream, send_chan, _ = mplex_stream + + # Mock lock methods + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_read = AsyncMock(wraps=original_release) + + await send_chan.send(b"hello") + result = await stream.read(5) + + stream.rw_lock.acquire_read.assert_awaited_once() + stream.rw_lock.release_read.assert_awaited_once() + assert result == b"hello" + + +@pytest.mark.trio +async def test_multiple_readers_can_coexist(mplex_stream): + """Verify multiple readers can operate concurrently.""" + stream, send_chan, _ = mplex_stream + + # Send enough data for both reads + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Track lock acquisition order + acquisition_order = [] + release_order = [] + + # Patch lock methods to track concurrency + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + async def tracked_acquire(): + nonlocal acquisition_order + acquisition_order.append("start") + await original_acquire() + 
acquisition_order.append("acquired") + + async def tracked_release(): + nonlocal release_order + release_order.append("start") + await original_release() + release_order.append("released") + + with ( + patch.object( + stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True + ), + patch.object( + stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True + ), + ): + # Execute concurrent reads + async with trio.open_nursery() as nursery: + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Verify both reads happened + assert acquisition_order.count("start") == 2 + assert acquisition_order.count("acquired") == 2 + assert release_order.count("start") == 2 + assert release_order.count("released") == 2 + + +@pytest.mark.trio +async def test_writer_blocks_readers(mplex_stream): + """Verify that a writer blocks all readers and new readers queue behind.""" + stream, send_chan, _ = mplex_stream + + writer_acquired = trio.Event() + readers_ready = trio.Event() + writer_finished = trio.Event() + all_readers_started = trio.Event() + all_readers_done = trio.Event() + + counters = {"reader_start_count": 0, "reader_done_count": 0} + reader_target = 3 + reader_start_lock = trio.Lock() + + # Patch write lock to control test flow + original_acquire_write = stream.rw_lock.acquire_write + original_release_write = stream.rw_lock.release_write + + async def tracked_acquire_write(): + await original_acquire_write() + writer_acquired.set() + # Wait for readers to queue up + await readers_ready.wait() + + # Must be synchronous since real release_write is sync + def tracked_release_write(): + original_release_write() + writer_finished.set() + + with ( + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + patch.object( + stream.rw_lock, "release_write", side_effect=tracked_release_write + ), + ): + async with trio.open_nursery() as nursery: + # Start writer + nursery.start_soon(stream.write, b"test") + await writer_acquired.wait() + + # Start readers + async def reader_task(): + async with reader_start_lock: + counters["reader_start_count"] += 1 + if counters["reader_start_count"] == reader_target: + all_readers_started.set() + + try: + # This will block until data is available + await stream.read(5) + except (MplexStreamReset, MplexStreamEOF): + pass + finally: + async with reader_start_lock: + counters["reader_done_count"] += 1 + if counters["reader_done_count"] == reader_target: + all_readers_done.set() + + for _ in range(reader_target): + nursery.start_soon(reader_task) + + # Wait until all readers are started + await all_readers_started.wait() + + # Let the writer finish and release the lock + readers_ready.set() + await writer_finished.wait() + + # Send data to unblock the readers + for i in range(reader_target): + await send_chan.send(b"data" + str(i).encode()) + + # Wait for all readers to finish + await all_readers_done.wait() + + +@pytest.mark.trio +async def test_writer_waits_for_readers(mplex_stream): + """Verify a writer waits for existing readers to complete.""" + stream, send_chan, _ = mplex_stream + readers_started = trio.Event() + writer_entered = trio.Event() + writer_acquiring = trio.Event() + readers_finished = trio.Event() + + # Send data for readers + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Patch read lock to control test flow + original_acquire_read = stream.rw_lock.acquire_read + + async def tracked_acquire_read(): + await original_acquire_read() + 
readers_started.set() + # Wait until readers are allowed to finish + await readers_finished.wait() + + # Patch write lock to detect when writer is blocked + original_acquire_write = stream.rw_lock.acquire_write + + async def tracked_acquire_write(): + writer_acquiring.set() + await original_acquire_write() + writer_entered.set() + + with ( + patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read), + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + ): + async with trio.open_nursery() as nursery: + # Start readers + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Wait for at least one reader to acquire the lock + await readers_started.wait() + + # Start writer (should block) + nursery.start_soon(stream.write, b"test") + + # Wait for writer to start acquiring lock + await writer_acquiring.wait() + + # Verify writer hasn't entered critical section + assert not writer_entered.is_set() + + # Allow readers to finish + readers_finished.set() + + # Verify writer can proceed + await writer_entered.wait() + + +@pytest.mark.trio +async def test_lock_behavior_during_cancellation(mplex_stream): + """Verify that a lock is released when a task holding it is cancelled.""" + stream, _, _ = mplex_stream + + reader_acquired_lock = trio.Event() + + async def cancellable_reader(task_status): + async with stream.rw_lock.read_lock(): + reader_acquired_lock.set() + task_status.started() + # Wait indefinitely until cancelled. + await trio.sleep_forever() + + async with trio.open_nursery() as nursery: + # Start the reader and wait for it to acquire the lock. + await nursery.start(cancellable_reader) + await reader_acquired_lock.wait() + + # Now that the reader has the lock, cancel the nursery. + # This will cancel the reader task, and its lock should be released. + nursery.cancel_scope.cancel() + + # After the nursery is cancelled, the reader should have released the lock. + # To verify, we try to acquire a write lock. If the read lock was not + # released, this will time out. + with trio.move_on_after(1) as cancel_scope: + async with stream.rw_lock.write_lock(): + pass + if cancel_scope.cancelled_caught: + pytest.fail( + "Write lock could not be acquired after a cancelled reader, " + "indicating the read lock was not released." + ) + + +@pytest.mark.trio +async def test_concurrent_read_write_sequence(mplex_stream): + """Verify complex sequence of interleaved reads and writes.""" + stream, send_chan, _ = mplex_stream + results = [] + # Use a mock to intercept writes and feed them back to the read channel + original_write = stream.write + + reader1_finished = trio.Event() + writer1_finished = trio.Event() + reader2_finished = trio.Event() + + async def mocked_write(data: bytes) -> None: + await original_write(data) + # Simulate the other side receiving the data and sending a response + # by putting data into the read channel. + await send_chan.send(data) + + with patch.object(stream, "write", wraps=mocked_write) as patched_write: + async with trio.open_nursery() as nursery: + # Test scenario: + # 1. Reader 1 starts, waits for data. + # 2. Writer 1 writes, which gets fed back to the stream. + # 3. Reader 2 starts, reads what Writer 1 wrote. + # 4. Writer 2 writes. 
+ + async def reader1(): + nonlocal results + results.append("R1 start") + data = await stream.read(5) + results.append(data) + results.append("R1 done") + reader1_finished.set() + + async def writer1(): + nonlocal results + await reader1_finished.wait() + results.append("W1 start") + await stream.write(b"write1") + results.append("W1 done") + writer1_finished.set() + + async def reader2(): + nonlocal results + await writer1_finished.wait() + # This will read the data from writer1 + results.append("R2 start") + data = await stream.read(6) + results.append(data) + results.append("R2 done") + reader2_finished.set() + + async def writer2(): + nonlocal results + await reader2_finished.wait() + results.append("W2 start") + await stream.write(b"write2") + results.append("W2 done") + + # Execute sequence + nursery.start_soon(reader1) + nursery.start_soon(writer1) + nursery.start_soon(reader2) + nursery.start_soon(writer2) + + await send_chan.send(b"data1") + + # Verify sequence and that write was called + assert patched_write.call_count == 2 + assert results == [ + "R1 start", + b"data1", + "R1 done", + "W1 start", + "W1 done", + "R2 start", + b"write1", + "R2 done", + "W2 start", + "W2 done", + ] + + +# =============================================== +# 2. Tests for Reset, EOF, and Close Interactions +# =============================================== + + +@pytest.mark.trio +async def test_read_after_remote_close_triggers_eof(mplex_stream): + """Verify reading from a remotely closed stream returns EOF correctly.""" + stream, send_chan, _ = mplex_stream + + # Send some data that can be read first + await send_chan.send(b"data") + # Close the channel to signify no more data will ever arrive + await send_chan.aclose() + + # Mark the stream as remotely closed + stream.event_remote_closed.set() + + # The first read should succeed, consuming the buffered data + data = await stream.read(4) + assert data == b"data" + + # Now that the buffer is empty and the channel is closed, this should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(1) + + +@pytest.mark.trio +async def test_read_on_closed_stream_raises_eof(mplex_stream): + """Test that reading from a closed stream with no data raises EOF.""" + stream, send_chan, _ = mplex_stream + stream.event_remote_closed.set() + await send_chan.aclose() # Ensure the channel is closed + + # Reading from a stream that is closed and has no data should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(100) + + +@pytest.mark.trio +async def test_write_to_locally_closed_stream_raises(mplex_stream): + """Verify writing to a locally closed stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + stream.event_local_closed.set() + + with pytest.raises(MplexStreamClosed): + await stream.write(b"this should fail") + + +@pytest.mark.trio +async def test_read_from_reset_stream_raises(mplex_stream): + """Verify reading from a reset stream raises MplexStreamReset.""" + stream, _, _ = mplex_stream + stream.event_reset.set() + + with pytest.raises(MplexStreamReset): + await stream.read(10) + + +@pytest.mark.trio +async def test_write_to_reset_stream_raises(mplex_stream): + """Verify writing to a reset stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + # A stream reset implies it's also locally closed. + await stream.reset() + + # The `write` method checks `event_local_closed`, which `reset` sets. 
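+    # (reset() is also expected to enqueue a ResetInitiator frame and remove
+    # the stream from the muxed connection; test_stream_reset_cleans_up_resources
+    # below verifies that cleanup.)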
+ with pytest.raises(MplexStreamClosed): + await stream.write(b"this should also fail") + + +@pytest.mark.trio +async def test_stream_reset_cleans_up_resources(mplex_stream): + """Verify reset() cleans up stream state and resources.""" + stream, _, muxed_conn = mplex_stream + stream_id = stream.stream_id + + assert stream_id in muxed_conn.streams + await stream.reset() + + assert stream.event_reset.is_set() + assert stream.event_local_closed.is_set() + assert stream.event_remote_closed.is_set() + assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages + assert stream_id not in muxed_conn.streams + # Verify the underlying data channel is closed + with pytest.raises(trio.ClosedResourceError): + await stream.incoming_data_channel.receive() + + +# =============================================== +# 3. Rigorous Concurrency Tests with Events +# =============================================== + + +@pytest.mark.trio +async def test_writer_is_blocked_by_reader_using_events(mplex_stream): + """Verify a writer must wait for a reader using trio.Event for synchronization.""" + stream, _, _ = mplex_stream + + reader_has_lock = trio.Event() + writer_finished = trio.Event() + + async def reader(): + async with stream.rw_lock.read_lock(): + reader_has_lock.set() + # Hold the lock until the writer has finished its attempt + await writer_finished.wait() + + async def writer(): + await reader_has_lock.wait() + # This call will now block until the reader releases the lock + await stream.write(b"data") + writer_finished.set() + + async with trio.open_nursery() as nursery: + nursery.start_soon(reader) + nursery.start_soon(writer) + + # Verify writer is blocked + await wait_all_tasks_blocked() + assert not writer_finished.is_set() + + # Signal the reader to finish + writer_finished.set() + + +@pytest.mark.trio +async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream): + """Verify that multiple readers can acquire a read lock simultaneously.""" + stream, _, _ = mplex_stream + + counters = {"readers_in_critical_section": 0} + lock = trio.Lock() # To safely mutate the counter + + reader1_acquired = trio.Event() + reader2_acquired = trio.Event() + all_readers_finished = trio.Event() + + async def concurrent_reader(event_to_set: trio.Event): + async with stream.rw_lock.read_lock(): + async with lock: + counters["readers_in_critical_section"] += 1 + event_to_set.set() + # Wait until all readers have finished before exiting the lock context + await all_readers_finished.wait() + async with lock: + counters["readers_in_critical_section"] -= 1 + + async with trio.open_nursery() as nursery: + nursery.start_soon(concurrent_reader, reader1_acquired) + nursery.start_soon(concurrent_reader, reader2_acquired) + + # Wait for both readers to acquire their locks + await reader1_acquired.wait() + await reader2_acquired.wait() + + # Check that both were in the critical section at the same time + async with lock: + assert counters["readers_in_critical_section"] == 2 + + # Signal for all readers to finish + all_readers_finished.set() + + # Verify they exit cleanly + await wait_all_tasks_blocked() + async with lock: + assert counters["readers_in_critical_section"] == 0
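+
+
+# ---------------------------------------------------------------------------
+# Reference sketch (illustration only, not exercised by the tests): the tests
+# above assume an `rw_lock` exposing async `acquire_read`/`release_read`/
+# `acquire_write`, a synchronous `release_write`, and `read_lock()`/
+# `write_lock()` context managers. A minimal trio-based, reader-preference
+# implementation could look like the hypothetical class below; the real lock
+# lives on `MplexStream` and may differ in detail.
+# ---------------------------------------------------------------------------
+from contextlib import asynccontextmanager  # needed only for this sketch
+
+
+class _SketchRWLock:
+    """Hypothetical reader-preference read-write lock, for illustration."""
+
+    def __init__(self) -> None:
+        self._num_readers = 0
+        self._counter_lock = trio.Lock()  # protects _num_readers
+        # A semaphore rather than trio.Lock, so the task that releases the
+        # writer gate need not be the task that acquired it.
+        self._writer_gate = trio.Semaphore(1)
+
+    async def acquire_read(self) -> None:
+        async with self._counter_lock:
+            if self._num_readers == 0:
+                # First reader locks writers out.
+                await self._writer_gate.acquire()
+            self._num_readers += 1
+
+    async def release_read(self) -> None:
+        async with self._counter_lock:
+            self._num_readers -= 1
+            if self._num_readers == 0:
+                # Last reader lets writers back in.
+                self._writer_gate.release()
+
+    async def acquire_write(self) -> None:
+        await self._writer_gate.acquire()
+
+    def release_write(self) -> None:
+        # Synchronous, matching the release_write patched in the tests above.
+        self._writer_gate.release()
+
+    @asynccontextmanager
+    async def read_lock(self):
+        await self.acquire_read()
+        try:
+            yield
+        finally:
+            # Shield the release so a cancelled holder still frees the lock,
+            # the behaviour test_lock_behavior_during_cancellation relies on.
+            with trio.CancelScope(shield=True):
+                await self.release_read()
+
+    @asynccontextmanager
+    async def write_lock(self):
+        await self.acquire_write()
+        try:
+            yield
+        finally:
+            self.release_write()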