From 51c7d53453d4a1b30b3552495151611807a8aa71 Mon Sep 17 00:00:00 2001 From: acul71 Date: Thu, 20 Mar 2025 03:31:56 +0100 Subject: [PATCH] test: add test_complete_remote_address_delegation_chain --- tests/core/identity/identify/test_protocol.py | 99 +++++++++++++++++-- 1 file changed, 89 insertions(+), 10 deletions(-) diff --git a/tests/core/identity/identify/test_protocol.py b/tests/core/identity/identify/test_protocol.py index e3777531..af87ab1a 100644 --- a/tests/core/identity/identify/test_protocol.py +++ b/tests/core/identity/identify/test_protocol.py @@ -14,12 +14,34 @@ from libp2p.identity.identify.protocol import ( PROTOCOL_VERSION, _mk_identify_protobuf, _multiaddr_to_bytes, + _remote_address_to_multiaddr, ) from tests.factories import ( host_pair_factory, ) -logger = logging.getLogger("libp2p.identity.identify-test") + +def clean_multiaddr(addr: Multiaddr) -> Multiaddr: + """ + Clean a multiaddr by removing the '/p2p/' part if it exists. + + Args: + addr: The multiaddr to clean + + Returns: + The cleaned multiaddr + """ + return Multiaddr.join( + *( + addr.split()[:-1] + if str(addr.split()[-1]).startswith("/p2p/") + else addr.split() + ) + ) + + +# logger = logging.getLogger("libp2p.identity.identify-test") +logger = logging.getLogger(__name__) @pytest.mark.trio @@ -29,8 +51,8 @@ async def test_identify_protocol(security_protocol): host_b, ): # Here, host_b is the requester and host_a is the responder. - # observed_addr represent host_b’s address as observed by host_a - # (i.e., the address from which host_b’s request was received). + # observed_addr represent host_b's address as observed by host_a + # (i.e., the address from which host_b's request was received). stream = await host_b.new_stream(host_a.get_id(), (ID,)) response = await stream.read() await stream.close() @@ -59,13 +81,7 @@ async def test_identify_protocol(security_protocol): # TODO: use decapsulateCode(protocols('p2p').code) # when the Multiaddr class will implement it host_b_addr = host_b.get_addrs()[0] - cleaned_addr = Multiaddr.join( - *( - host_b_addr.split()[:-1] - if str(host_b_addr.split()[-1]).startswith("/p2p/") - else host_b_addr.split() - ) - ) + cleaned_addr = clean_multiaddr(host_b_addr) logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr)) logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0]) @@ -77,3 +93,66 @@ async def test_identify_protocol(security_protocol): # sanity check assert identify_response == _mk_identify_protobuf(host_a, cleaned_addr) + + +@pytest.mark.trio +async def test_complete_remote_address_delegation_chain(security_protocol): + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + logger.debug( + "test_complete_remote_address_delegation_chain security_protocol: %s", + security_protocol, + ) + + # Create a stream from host_a to host_b + stream = await host_a.new_stream(host_b.get_id(), (ID,)) + + # Get references to all components in the delegation chain + mplex_stream = stream.muxed_stream + swarm_conn = host_a.get_network().connections[host_b.get_id()] + mplex = swarm_conn.muxed_conn + secure_session = mplex.secured_conn + raw_connection = secure_session.conn + trio_tcp_stream = raw_connection.stream + + # Get remote addresses at each layer + stream_addr = stream.get_remote_address() + muxed_stream_addr = stream.muxed_stream.get_remote_address() + mplex_addr = mplex_stream.muxed_conn.get_remote_address() + secure_session_addr = mplex.secured_conn.get_remote_address() + raw_connection_addr = secure_session.conn.get_remote_address() + trio_tcp_stream_addr = raw_connection.stream.get_remote_address() + socket_addr = trio_tcp_stream.stream.socket.getpeername() + + # Log all addresses + logger.debug("NetStream address: %s", stream_addr) + logger.debug("MplexStream address: %s", muxed_stream_addr) + logger.debug("Mplex address: %s", mplex_addr) + logger.debug("SecureSession address: %s", secure_session_addr) + logger.debug("RawConnection address: %s", raw_connection_addr) + logger.debug("TrioTCPStream address: %s", trio_tcp_stream_addr) + logger.debug("Socket address: %s", socket_addr) + + # Verify complete delegation chain + assert ( + stream_addr + == muxed_stream_addr + == mplex_addr + == secure_session_addr + == raw_connection_addr + == trio_tcp_stream_addr + == socket_addr + ) + + # Convert to multiaddr and verify it matches host_b's cleaned address + remote_address_multiaddr = _remote_address_to_multiaddr(stream_addr) + host_b_addr = clean_multiaddr(host_b.get_addrs()[0]) + + logger.debug("Remote address multiaddr: %s", remote_address_multiaddr) + logger.debug("Host B cleaned address: %s", host_b_addr) + + assert remote_address_multiaddr == host_b_addr + + await stream.close()