diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index e23e014a..40cb7893 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -2,6 +2,7 @@ from collections import ( defaultdict, ) from collections.abc import ( + AsyncIterable, Sequence, ) from typing import ( @@ -11,6 +12,8 @@ from typing import ( from multiaddr import ( Multiaddr, ) +import trio +from trio import MemoryReceiveChannel, MemorySendChannel from libp2p.abc import ( IPeerStore, @@ -40,6 +43,7 @@ class PeerStore(IPeerStore): def __init__(self) -> None: self.peer_data_map = defaultdict(PeerData) + self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {} def peer_info(self, peer_id: ID) -> PeerInfo: """ @@ -178,6 +182,13 @@ class PeerStore(IPeerStore): peer_data.set_ttl(ttl) peer_data.update_last_identified() + if peer_id in self.addr_update_channels: + for addr in addrs: + try: + self.addr_update_channels[peer_id].send_nowait(addr) + except trio.WouldBlock: + pass # Or consider logging / dropping / replacing stream + def addrs(self, peer_id: ID) -> list[Multiaddr]: """ :param peer_id: peer ID to get addrs for @@ -217,6 +228,25 @@ class PeerStore(IPeerStore): peer_data.clear_addrs() return output + async def addr_stream(self, peer_id: ID) -> AsyncIterable[Multiaddr]: + """ + Returns an async stream of newly added addresses for the given peer. + + This function allows consumers to subscribe to address updates for a peer + and receive each new address as it is added via `add_addr` or `add_addrs`. + + :param peer_id: The ID of the peer to monitor address updates for. + :return: An async iterator yielding Multiaddr instances as they are added. + """ + send: MemorySendChannel[Multiaddr] + receive: MemoryReceiveChannel[Multiaddr] + + send, receive = trio.open_memory_channel(0) + self.addr_update_channels[peer_id] = send + + async for addr in receive: + yield addr + # -------KEY-BOOK--------- def add_pubkey(self, peer_id: ID, pubkey: PublicKey) -> None: diff --git a/tests/core/peer/test_peerstore.py b/tests/core/peer/test_peerstore.py index b0d8ed81..85fc1863 100644 --- a/tests/core/peer/test_peerstore.py +++ b/tests/core/peer/test_peerstore.py @@ -2,6 +2,7 @@ import time import pytest from multiaddr import Multiaddr +import trio from libp2p.peer.id import ID from libp2p.peer.peerstore import ( @@ -89,3 +90,36 @@ def test_peers(): store.add_addr(ID(b"peer3"), Multiaddr("/ip4/127.0.0.1/tcp/4001"), 10) assert set(store.peer_ids()) == {ID(b"peer1"), ID(b"peer2"), ID(b"peer3")} + + +@pytest.mark.trio +async def test_addr_stream_yields_new_addrs(): + store = PeerStore() + peer_id = ID(b"peer1") + addr1 = Multiaddr("/ip4/127.0.0.1/tcp/4001") + addr2 = Multiaddr("/ip4/127.0.0.1/tcp/4002") + + # 🔧 Pre-initialize peer in peer_data_map + # store.add_addr(peer_id, Multiaddr("/ip4/127.0.0.1/tcp/0"), ttl=1) + + collected = [] + + async def consume_addrs(): + async for addr in store.addr_stream(peer_id): + collected.append(addr) + if len(collected) == 2: + break + + async with trio.open_nursery() as nursery: + nursery.start_soon(consume_addrs) + await trio.sleep(2) # Give time for the stream to start + + store.add_addr(peer_id, addr1, ttl=10) + await trio.sleep(0.2) + store.add_addr(peer_id, addr2, ttl=10) + await trio.sleep(0.2) + + # After collecting expected addresses, cancel the stream + nursery.cancel_scope.cancel() + + assert collected == [addr1, addr2]