From 798229cd3a50bc779b55897c5e61d17faf452c11 Mon Sep 17 00:00:00 2001 From: acul71 Date: Tue, 25 Feb 2025 04:58:47 +0100 Subject: [PATCH] feat: add observed_addr to identify protocol --- libp2p/abc.py | 6 +++ libp2p/identity/identify/protocol.py | 37 ++++++++++++++----- tests/core/identity/identify/test_protocol.py | 30 +++++++++++++-- 3 files changed, 60 insertions(+), 13 deletions(-) diff --git a/libp2p/abc.py b/libp2p/abc.py index b09e4322..688b1623 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -1131,6 +1131,12 @@ class IHost(ABC): """ + @abstractmethod + def get_peerstore(self) -> IPeerStore: + """ + :return: the peerstore of the host + """ + @abstractmethod def get_connected_peers(self) -> list[ID]: """ diff --git a/libp2p/identity/identify/protocol.py b/libp2p/identity/identify/protocol.py index 26d36f84..98e6effe 100644 --- a/libp2p/identity/identify/protocol.py +++ b/libp2p/identity/identify/protocol.py @@ -1,4 +1,7 @@ import logging +from typing import ( + Optional, +) from multiaddr import ( Multiaddr, @@ -15,9 +18,6 @@ from libp2p.custom_types import ( from libp2p.network.stream.exceptions import ( StreamClosed, ) -from libp2p.network.stream.net_stream_interface import ( - INetStream, -) from libp2p.utils import ( get_agent_version, ) @@ -26,7 +26,9 @@ from .pb.identify_pb2 import ( Identify, ) -logger = logging.getLogger("libp2p.identity.identify") +# Not sure I can do this or I break a pattern +# logger = logging.getLogger("libp2p.identity.identify") +logger = logging.getLogger(__name__) ID = TProtocol("/ipfs/id/1.0.0") PROTOCOL_VERSION = "ipfs/0.1.0" @@ -37,28 +39,45 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes: return maddr.to_bytes() -def _mk_identify_protobuf(host: IHost) -> Identify: +def _mk_identify_protobuf( + host: IHost, observed_multiaddr: Optional[Multiaddr] +) -> Identify: public_key = host.get_public_key() laddrs = host.get_addrs() protocols = host.get_mux().get_protocols() + observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( protocol_version=PROTOCOL_VERSION, agent_version=AGENT_VERSION, public_key=public_key.serialize(), listen_addrs=map(_multiaddr_to_bytes, laddrs), - # TODO send observed address from ``stream`` - observed_addr=b"", + observed_addr=observed_addr, protocols=protocols, ) def identify_handler_for(host: IHost) -> StreamHandlerFn: async def handle_identify(stream: INetStream) -> None: - peer_id = stream.muxed_conn.peer_id + # get observed address from ``stream`` + # class Swarm(Service, INetworkService): + # TODO: Connection and `peer_id` are 1-1 mapping in our implementation, + # whereas in Go one `peer_id` may point to multiple connections. + # connections: dict[ID, INetConn] + # Luca: So I'm assuming that the connection is 1-1 mapping for now + peer_id = stream.muxed_conn.peer_id # remote peer_id + peer_store = host.get_peerstore() # get the peer store from the host + remote_peer_multiaddrs = peer_store.addrs( + peer_id + ) # get the Multiaddrs for the remote peer_id + logger.debug("multiaddrs of remote peer is %s", remote_peer_multiaddrs) logger.debug("received a request for %s from %s", ID, peer_id) - protobuf = _mk_identify_protobuf(host) + # Select the first address if available, else None + observed_multiaddr = ( + remote_peer_multiaddrs[0] if remote_peer_multiaddrs else None + ) + protobuf = _mk_identify_protobuf(host, observed_multiaddr) response = protobuf.SerializeToString() try: diff --git a/tests/core/identity/identify/test_protocol.py b/tests/core/identity/identify/test_protocol.py index fd22c8a5..e532da3d 100644 --- a/tests/core/identity/identify/test_protocol.py +++ b/tests/core/identity/identify/test_protocol.py @@ -1,4 +1,9 @@ +import logging + import pytest +from multiaddr import ( + Multiaddr, +) from libp2p.identity.identify.pb.identify_pb2 import ( Identify, @@ -14,6 +19,8 @@ from tests.factories import ( host_pair_factory, ) +logger = logging.getLogger("libp2p.identity.identify-test") + @pytest.mark.trio async def test_identify_protocol(security_protocol): @@ -28,8 +35,8 @@ async def test_identify_protocol(security_protocol): identify_response = Identify() identify_response.ParseFromString(response) - # sanity check - assert identify_response == _mk_identify_protobuf(host_a) + logger.debug("host_a: %s", host_a.get_addrs()) + logger.debug("host_b: %s", host_b.get_addrs()) # Check protocol version assert identify_response.protocol_version == PROTOCOL_VERSION @@ -45,8 +52,23 @@ async def test_identify_protocol(security_protocol): map(_multiaddr_to_bytes, host_a.get_addrs()) ) - # TODO: Check observed address - # assert identify_response.observed_addr == host_b.get_addrs()[0] + # Check observed address + host_b_addr = host_b.get_addrs()[0] + cleaned_addr = Multiaddr.join( + *( + host_b_addr.split()[:-1] + if str(host_b_addr.split()[-1]).startswith("/p2p/") + else host_b_addr.split() + ) + ) + + logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr)) + logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0]) + logger.debug("cleaned_addr= %s", cleaned_addr) + assert identify_response.observed_addr == _multiaddr_to_bytes(cleaned_addr) # Check protocols assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols()) + + # sanity check + assert identify_response == _mk_identify_protobuf(host_a, cleaned_addr)