diff --git a/examples/identify/identify.py b/examples/identify/identify.py index 4882d2c3..60ccb75c 100644 --- a/examples/identify/identify.py +++ b/examples/identify/identify.py @@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") format_name = "length-prefixed" if use_varint_format else "raw protobuf" + format_flag = "--raw-format" if not use_varint_format else "" print( f"First host listening (using {format_name} format). " f"Run this from another console:\n\n" - f"identify-demo " - f"-d {client_addr}\n" + f"identify-demo {format_flag} -d {client_addr}\n" ) print("Waiting for incoming identify request...") + + # Add a custom handler to show connection events + async def custom_identify_handler(stream): + peer_id = stream.muxed_conn.peer_id + print(f"\nšŸ”— Received identify request from peer: {peer_id}") + + # Show remote address in multiaddr format + try: + from libp2p.identity.identify.identify import ( + _remote_address_to_multiaddr, + ) + + remote_address = stream.get_remote_address() + if remote_address: + observed_multiaddr = _remote_address_to_multiaddr( + remote_address + ) + # Add the peer ID to create a complete multiaddr + complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}" + print(f" Remote address: {complete_multiaddr}") + else: + print(f" Remote address: {remote_address}") + except Exception: + print(f" Remote address: {stream.get_remote_address()}") + + # Call the original handler + await identify_handler(stream) + + print(f"āœ… Successfully processed identify request from {peer_id}") + + # Replace the handler with our custom one + host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler) + await trio.sleep_forever() else: @@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No info = info_from_p2p_addr(maddr) print(f"Second host connecting to peer: {info.peer_id}") - await host_b.connect(info) + try: + await host_b.connect(info) + except Exception as e: + error_msg = str(e) + if "unable to connect" in error_msg or "SwarmException" in error_msg: + print(f"\nāŒ Cannot connect to peer: {info.peer_id}") + print(f" Address: {destination}") + print(f" Error: {error_msg}") + print( + "\nšŸ’” Make sure the peer is running and the address is correct." + ) + return + else: + # Re-raise other exceptions + raise + stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,)) try: print("Starting identify protocol...") - # Read the complete response (could be either format) - # Read a larger chunk to get all the data before stream closes - response = await stream.read(8192) # Read enough data in one go + # Read the response properly based on the format + if use_varint_format: + # For length-prefixed format, read varint length first + from libp2p.utils.varint import decode_varint_from_bytes + + # Read varint length prefix + length_bytes = b"" + while True: + b = await stream.read(1) + if not b: + raise Exception("Stream closed while reading varint length") + length_bytes += b + if b[0] & 0x80 == 0: + break + + msg_length = decode_varint_from_bytes(length_bytes) + print(f"Expected message length: {msg_length} bytes") + + # Read the protobuf message + response = await stream.read(msg_length) + if len(response) != msg_length: + raise Exception( + f"Incomplete message: expected {msg_length} bytes, " + f"got {len(response)}" + ) + + # Combine length prefix and message + full_response = length_bytes + response + else: + # For raw format, read all available data + response = await stream.read(8192) + full_response = response await stream.close() # Parse the response using the robust protocol-level function # This handles both old and new formats automatically - identify_msg = parse_identify_response(response) + identify_msg = parse_identify_response(full_response) print_identify_response(identify_msg) except Exception as e: - print(f"Identify protocol error: {e}") + error_msg = str(e) + print(f"Identify protocol error: {error_msg}") + + # Check for specific format mismatch errors + if "Error parsing message" in error_msg or "DecodeError" in error_msg: + print("\n" + "=" * 60) + print("FORMAT MISMATCH DETECTED!") + print("=" * 60) + if use_varint_format: + print( + "You are using length-prefixed format (default) but the " + "listener" + ) + print("is using raw protobuf format.") + print( + "\nTo fix this, run the dialer with the --raw-format flag:" + ) + print(f"identify-demo --raw-format -d {destination}") + else: + print("You are using raw protobuf format but the listener") + print("is using length-prefixed format (default).") + print( + "\nTo fix this, run the dialer without the --raw-format " + "flag:" + ) + print(f"identify-demo -d {destination}") + print("=" * 60) + else: + import traceback + + traceback.print_exc() return diff --git a/tests/core/identity/identify/test_identify_integration.py b/tests/core/identity/identify/test_identify_integration.py new file mode 100644 index 00000000..e4ebcba7 --- /dev/null +++ b/tests/core/identity/identify/test_identify_integration.py @@ -0,0 +1,241 @@ +import logging + +import pytest + +from libp2p.custom_types import TProtocol +from libp2p.identity.identify.identify import ( + AGENT_VERSION, + ID, + PROTOCOL_VERSION, + _multiaddr_to_bytes, + identify_handler_for, + parse_identify_response, +) +from tests.utils.factories import host_pair_factory + +logger = logging.getLogger("libp2p.identity.identify-integration-test") + + +@pytest.mark.trio +async def test_identify_protocol_varint_format_integration(security_protocol): + """Test identify protocol with varint format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_protocol_raw_format_integration(security_protocol): + """Test identify protocol with raw format in real network scenario.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + assert result.listen_addrs == [ + _multiaddr_to_bytes(addr) for addr in host_a.get_addrs() + ] + + +@pytest.mark.trio +async def test_identify_default_format_behavior(security_protocol): + """Test identify protocol uses correct default format.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Use default identify handler (should use varint format) + host_a.set_stream_handler(ID, identify_handler_for(host_a)) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol): + """Test varint dialer with raw listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol): + """Test raw dialer with varint listener compatibility.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Host A uses varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Host B makes request (will automatically detect format) + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response (should work with automatic format detection) + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_format_detection_robustness(security_protocol): + """Test identify protocol format detection is robust with various message sizes.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test both formats with different message sizes + for use_varint in [True, False]: + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=use_varint) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_large_message_handling(security_protocol): + """Test identify protocol handles large messages with many protocols.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Add many protocols to make the message larger + async def dummy_handler(stream): + pass + + for i in range(10): + host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler) + + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response content + assert result.agent_version == AGENT_VERSION + assert result.protocol_version == PROTOCOL_VERSION + assert result.public_key == host_a.get_public_key().serialize() + + +@pytest.mark.trio +async def test_identify_message_equivalence_real_network(security_protocol): + """Test that both formats produce equivalent messages in real network.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Test varint format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=True) + ) + stream_varint = await host_b.new_stream(host_a.get_id(), (ID,)) + response_varint = await stream_varint.read(8192) + await stream_varint.close() + + # Test raw format + host_a.set_stream_handler( + ID, identify_handler_for(host_a, use_varint_format=False) + ) + stream_raw = await host_b.new_stream(host_a.get_id(), (ID,)) + response_raw = await stream_raw.read(8192) + await stream_raw.close() + + # Parse both responses + result_varint = parse_identify_response(response_varint) + result_raw = parse_identify_response(response_raw) + + # Both should produce identical parsed results + assert result_varint.agent_version == result_raw.agent_version + assert result_varint.protocol_version == result_raw.protocol_version + assert result_varint.public_key == result_raw.public_key + assert result_varint.listen_addrs == result_raw.listen_addrs diff --git a/tests/core/identity/identify/test_identify_parsing.py b/tests/core/identity/identify/test_identify_parsing.py deleted file mode 100644 index d76d82a1..00000000 --- a/tests/core/identity/identify/test_identify_parsing.py +++ /dev/null @@ -1,410 +0,0 @@ -import pytest - -from libp2p.identity.identify.identify import ( - _mk_identify_protobuf, -) -from libp2p.identity.identify.pb.identify_pb2 import ( - Identify, -) -from libp2p.io.abc import Closer, Reader, Writer -from libp2p.utils.varint import ( - decode_varint_from_bytes, - encode_varint_prefixed, -) -from tests.utils.factories import ( - host_pair_factory, -) - - -class MockStream(Reader, Writer, Closer): - """Mock stream for testing identify protocol compatibility.""" - - def __init__(self, data: bytes): - self.data = data - self.position = 0 - self.closed = False - - async def read(self, n: int | None = None) -> bytes: - if self.closed or self.position >= len(self.data): - return b"" - if n is None: - n = len(self.data) - self.position - result = self.data[self.position : self.position + n] - self.position += len(result) - return result - - async def write(self, data: bytes) -> None: - # Mock write - just store the data - pass - - async def close(self) -> None: - self.closed = True - - -def create_identify_message(host, observed_multiaddr=None): - """Create an identify protobuf message.""" - return _mk_identify_protobuf(host, observed_multiaddr) - - -def create_new_format_message(identify_msg): - """Create a new format (length-prefixed) identify message.""" - msg_bytes = identify_msg.SerializeToString() - return encode_varint_prefixed(msg_bytes) - - -def create_old_format_message(identify_msg): - """Create an old format (raw protobuf) identify message.""" - return identify_msg.SerializeToString() - - -async def read_new_format_message(stream) -> bytes: - """Read a new format (length-prefixed) identify message.""" - # Read varint length prefix - length_bytes = b"" - while True: - b = await stream.read(1) - if not b: - break - length_bytes += b - if b[0] & 0x80 == 0: - break - - if not length_bytes: - raise ValueError("No length prefix received") - - msg_length = decode_varint_from_bytes(length_bytes) - - # Read the protobuf message - response = await stream.read(msg_length) - if len(response) != msg_length: - raise ValueError("Incomplete message received") - - return response - - -async def read_old_format_message(stream) -> bytes: - """Read an old format (raw protobuf) identify message.""" - # Read all available data - response = b"" - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message(stream) -> bytes: - """Read an identify message in either old or new format.""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining protobuf message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - message_data = first_bytes[varint_len:] + remaining_bytes - - # Try to parse as protobuf to validate - try: - Identify().ParseFromString(message_data) - return message_data - except Exception: - # If protobuf parsing fails, fall back to old format - pass - except Exception: - pass - - # Fall back to old format (raw protobuf) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -async def read_compatible_message_simple(stream) -> bytes: - """Read a message in either old or new format (simplified version for testing).""" - # Try to read a few bytes to detect the format - first_bytes = await stream.read(10) - if not first_bytes: - raise ValueError("No data received") - - # Try to decode as varint length prefix (new format) - try: - msg_length = decode_varint_from_bytes(first_bytes) - - # Validate that the length is reasonable (not too large) - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate how many bytes the varint consumed - varint_len = 0 - for i, byte in enumerate(first_bytes): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Read the remaining message - remaining_bytes = await stream.read( - msg_length - (len(first_bytes) - varint_len) - ) - if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len): - return first_bytes[varint_len:] + remaining_bytes - except Exception: - pass - - # Fall back to old format (raw data) - response = first_bytes - - # Read more data if available - while True: - chunk = await stream.read(4096) - if not chunk: - break - response += chunk - - return response - - -def detect_format(data): - """Detect if data is in new or old format (varint-prefixed or raw protobuf).""" - if not data: - return "unknown" - - # Try to decode as varint - try: - msg_length = decode_varint_from_bytes(data) - - # Validate that the length is reasonable - if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB - # Calculate varint length - varint_len = 0 - for i, byte in enumerate(data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - # Check if we have enough data for the message - if len(data) >= varint_len + msg_length: - # Additional check: try to parse the message as protobuf - try: - message_data = data[varint_len : varint_len + msg_length] - Identify().ParseFromString(message_data) - return "new" - except Exception: - # If protobuf parsing fails, it's probably not a valid new format - pass - except Exception: - pass - - # If varint decoding fails or length is unreasonable, assume old format - return "old" - - -@pytest.mark.trio -async def test_identify_new_format_compatibility(security_protocol): - """Test that identify protocol works with new format (length-prefixed) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_old_format_compatibility(security_protocol): - """Test that identify protocol works with old format (raw protobuf) messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_old_format(security_protocol): - """Test backward compatibility reader with old format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create old format message - old_format_data = create_old_format_message(identify_msg) - - # Create mock stream with old format data - stream = MockStream(old_format_data) - - # Read using old format reader (which should work reliably) - response = await read_old_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_backward_compatibility_new_format(security_protocol): - """Test backward compatibility reader with new format messages.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create new format message - new_format_data = create_new_format_message(identify_msg) - - # Create mock stream with new format data - stream = MockStream(new_format_data) - - # Read using new format reader (which should work reliably) - response = await read_new_format_message(stream) - - # Parse the response - parsed_msg = Identify() - parsed_msg.ParseFromString(response) - - # Verify the message content - assert parsed_msg.protocol_version == identify_msg.protocol_version - assert parsed_msg.agent_version == identify_msg.agent_version - assert parsed_msg.public_key == identify_msg.public_key - - -@pytest.mark.trio -async def test_identify_format_detection(security_protocol): - """Test that the format detection works correctly.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Test new format detection - new_format_data = create_new_format_message(identify_msg) - format_type = detect_format(new_format_data) - assert format_type == "new", "New format should be detected correctly" - - # Test old format detection - old_format_data = create_old_format_message(identify_msg) - format_type = detect_format(old_format_data) - assert format_type == "old", "Old format should be detected correctly" - - -@pytest.mark.trio -async def test_identify_error_handling(security_protocol): - """Test error handling for malformed messages.""" - from libp2p.exceptions import ParseError - - # Test with empty data - stream = MockStream(b"") - with pytest.raises(ValueError, match="No data received"): - await read_compatible_message(stream) - - # Test with incomplete varint - stream = MockStream(b"\x80") # Incomplete varint - with pytest.raises(ParseError, match="Unexpected end of data"): - await read_new_format_message(stream) - - # Test with invalid protobuf data - stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf - with pytest.raises(Exception): # Should fail when parsing protobuf - response = await read_new_format_message(stream) - Identify().ParseFromString(response) - - -@pytest.mark.trio -async def test_identify_message_equivalence(security_protocol): - """Test that old and new format messages are equivalent.""" - async with host_pair_factory(security_protocol=security_protocol) as ( - host_a, - host_b, - ): - # Create identify message - identify_msg = create_identify_message(host_a) - - # Create both formats - new_format_data = create_new_format_message(identify_msg) - old_format_data = create_old_format_message(identify_msg) - - # Extract the protobuf message from new format - varint_len = 0 - for i, byte in enumerate(new_format_data): - varint_len += 1 - if (byte & 0x80) == 0: - break - - new_format_protobuf = new_format_data[varint_len:] - - # The protobuf messages should be identical - assert new_format_protobuf == old_format_data, ( - "Protobuf messages should be identical in both formats" - )