diff --git a/libp2p/identity/identify_push/identify_push.py b/libp2p/identity/identify_push/identify_push.py index c649c368..914264ed 100644 --- a/libp2p/identity/identify_push/identify_push.py +++ b/libp2p/identity/identify_push/identify_push.py @@ -40,6 +40,7 @@ logger = logging.getLogger(__name__) ID_PUSH = TProtocol("/ipfs/id/push/1.0.0") PROTOCOL_VERSION = "ipfs/0.1.0" AGENT_VERSION = get_agent_version() +CONCURRENCY_LIMIT = 10 def identify_push_handler_for(host: IHost) -> StreamHandlerFn: @@ -132,7 +133,10 @@ async def _update_peerstore_from_identify( async def push_identify_to_peer( - host: IHost, peer_id: ID, observed_multiaddr: Multiaddr | None = None + host: IHost, + peer_id: ID, + observed_multiaddr: Multiaddr | None = None, + limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT), ) -> bool: """ Push an identify message to a specific peer. @@ -146,25 +150,26 @@ async def push_identify_to_peer( True if the push was successful, False otherwise. """ - try: - # Create a new stream to the peer using the identify/push protocol - stream = await host.new_stream(peer_id, [ID_PUSH]) + async with limit: + try: + # Create a new stream to the peer using the identify/push protocol + stream = await host.new_stream(peer_id, [ID_PUSH]) - # Create the identify message - identify_msg = _mk_identify_protobuf(host, observed_multiaddr) - response = identify_msg.SerializeToString() + # Create the identify message + identify_msg = _mk_identify_protobuf(host, observed_multiaddr) + response = identify_msg.SerializeToString() - # Send the identify message - await stream.write(response) + # Send the identify message + await stream.write(response) - # Close the stream - await stream.close() + # Close the stream + await stream.close() - logger.debug("Successfully pushed identify to peer %s", peer_id) - return True - except Exception as e: - logger.error("Error pushing identify to peer %s: %s", peer_id, e) - return False + logger.debug("Successfully pushed identify to peer %s", peer_id) + return True + except Exception as e: + logger.error("Error pushing identify to peer %s: %s", peer_id, e) + return False async def push_identify_to_peers( @@ -179,13 +184,10 @@ async def push_identify_to_peers( """ if peer_ids is None: # Get all connected peers - peer_ids = set(host.get_peerstore().peer_ids()) + peer_ids = set(host.get_connected_peers()) # Push to each peer in parallel using a trio.Nursery - # TODO: Consider using a bounded nursery to limit concurrency - # and avoid overwhelming the network. This can be done by using - # trio.open_nursery(max_concurrent=10) or similar. - # For now, we will use an unbounded nursery for simplicity. + # limiting concurrent connections to 10 async with trio.open_nursery() as nursery: for peer_id in peer_ids: nursery.start_soon(push_identify_to_peer, host, peer_id, observed_multiaddr) diff --git a/newsfragments/621.feature.rst b/newsfragments/621.feature.rst new file mode 100644 index 00000000..7ed27fac --- /dev/null +++ b/newsfragments/621.feature.rst @@ -0,0 +1 @@ +Limit concurrency in `push_identify_to_peers` to prevent resource congestion under high peer counts. diff --git a/pyproject.toml b/pyproject.toml index cf000156..604949fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,3 @@ - [build-system] requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" @@ -23,7 +22,7 @@ dependencies = [ "multiaddr>=0.0.9", "mypy-protobuf>=3.0.0", "noiseprotocol>=0.3.0", - "protobuf>=3.20.1,<4.0.0", + "protobuf>=4.21.0,<5.0.0", "pycryptodome>=3.9.2", "pymultihash>=0.8.2", "pynacl>=1.3.0", diff --git a/tests/core/identity/identify_push/test_identify_push.py b/tests/core/identity/identify_push/test_identify_push.py index 1b875e6f..b0ffb677 100644 --- a/tests/core/identity/identify_push/test_identify_push.py +++ b/tests/core/identity/identify_push/test_identify_push.py @@ -1,4 +1,7 @@ import logging +from unittest.mock import ( + patch, +) import pytest import multiaddr @@ -17,6 +20,7 @@ from libp2p.identity.identify.pb.identify_pb2 import ( Identify, ) from libp2p.identity.identify_push.identify_push import ( + CONCURRENCY_LIMIT, ID_PUSH, _update_peerstore_from_identify, identify_push_handler_for, @@ -29,6 +33,9 @@ from libp2p.peer.peerinfo import ( from tests.utils.factories import ( host_pair_factory, ) +from tests.utils.utils import ( + create_mock_connections, +) logger = logging.getLogger("libp2p.identity.identify-push-test") @@ -175,6 +182,7 @@ async def test_identify_push_to_peers(security_protocol): host_c = new_host(key_pair=key_pair_c) # Set up the identify/push handlers + host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b)) host_c.set_stream_handler(ID_PUSH, identify_push_handler_for(host_c)) @@ -204,6 +212,20 @@ async def test_identify_push_to_peers(security_protocol): # Check that the peer is in the peerstore assert peer_id_a in peerstore_c.peer_ids() + # Test for push_identify to only connected peers and not all peers + # Disconnect a from c. + await host_c.disconnect(host_a.get_id()) + + await push_identify_to_peers(host_c) + + # Wait a bit for the push to complete + await trio.sleep(0.1) + + # Check that host_a's peerstore has not been updated with host_c's info + assert host_c.get_id() not in host_a.get_peerstore().peer_ids() + # Check that host_b's peerstore has been updated with host_c's info + assert host_c.get_id() in host_b.get_peerstore().peer_ids() + @pytest.mark.trio async def test_push_identify_to_peers_with_explicit_params(security_protocol): @@ -412,3 +434,72 @@ async def test_partial_update_peerstore_from_identify(security_protocol): host_a_public_key = host_a.get_public_key().serialize() peerstore_public_key = peerstore.pubkey(peer_id).serialize() assert host_a_public_key == peerstore_public_key + + +@pytest.mark.trio +async def test_push_identify_to_peers_respects_concurrency_limit(): + """ + Test bounded concurrency for the identify/push protocol to prevent + network congestion. + + This test verifies: + 1. The number of concurrent tasks executing the identify push is always + less than or equal to CONCURRENCY_LIMIT. + 2. An error is raised if concurrency exceeds the defined limit. + + It mocks `push_identify_to_peer` to simulate delay using sleep, + allowing the test to measure and assert actual concurrency behavior. + """ + state = { + "concurrency_counter": 0, + "max_observed": 0, + } + lock = trio.Lock() + + async def mock_push_identify_to_peer( + host, peer_id, observed_multiaddr=None, limit=trio.Semaphore(CONCURRENCY_LIMIT) + ) -> bool: + """ + Mock function to test concurrency by simulating an identify message. + + This function patches push_identify_to_peer for testing purpose + + Returns + ------- + bool + True if the push was successful, False otherwise. + + """ + async with limit: + async with lock: + state["concurrency_counter"] += 1 + if state["concurrency_counter"] > CONCURRENCY_LIMIT: + raise RuntimeError( + f"Concurrency limit exceeded: {state['concurrency_counter']}" + ) + state["max_observed"] = max( + state["max_observed"], state["concurrency_counter"] + ) + + logger.debug("Successfully pushed identify to peer %s", peer_id) + await trio.sleep(0.05) + + async with lock: + state["concurrency_counter"] -= 1 + + return True + + # Create a mock host. + key_pair_host = create_new_key_pair() + host = new_host(key_pair=key_pair_host) + + # Create a mock network and add mock connections to the host + host.get_network().connections = create_mock_connections() + with patch( + "libp2p.identity.identify_push.identify_push.push_identify_to_peer", + new=mock_push_identify_to_peer, + ): + await push_identify_to_peers(host) + assert state["max_observed"] <= CONCURRENCY_LIMIT, ( + f"Max concurrency observed: {state['max_observed']}" + ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 00000000..6e23ecdd --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,14 @@ +from unittest.mock import ( + MagicMock, +) + + +def create_mock_connections() -> dict: + connections = {} + + for i in range(1, 31): + peer_id = f"peer-{i}" + mock_conn = MagicMock(name=f"INetConn-{i}") + connections[peer_id] = mock_conn + + return connections