Merge branch 'main' into add-ws-transport

This commit is contained in:
Manu Sheel Gupta
2025-08-18 22:00:20 +05:30
committed by GitHub
9 changed files with 421 additions and 36 deletions

View File

@ -12,13 +12,13 @@
[![Build Status](https://img.shields.io/github/actions/workflow/status/libp2p/py-libp2p/tox.yml?branch=main&label=build%20status)](https://github.com/libp2p/py-libp2p/actions/workflows/tox.yml) [![Build Status](https://img.shields.io/github/actions/workflow/status/libp2p/py-libp2p/tox.yml?branch=main&label=build%20status)](https://github.com/libp2p/py-libp2p/actions/workflows/tox.yml)
[![Docs build](https://readthedocs.org/projects/py-libp2p/badge/?version=latest)](http://py-libp2p.readthedocs.io/en/latest/?badge=latest) [![Docs build](https://readthedocs.org/projects/py-libp2p/badge/?version=latest)](http://py-libp2p.readthedocs.io/en/latest/?badge=latest)
> ⚠️ **Warning:** py-libp2p is an experimental and work-in-progress repo under development. We do not yet recommend using py-libp2p in production environments. > py-libp2p has moved beyond its experimental roots and is steadily progressing toward production readiness. The core features are stable, and were focused on refining performance, expanding protocol support, and ensuring smooth interop with other libp2p implementations. We welcome contributions and real-world usage feedback to help us reach full production maturity.
Read more in the [documentation on ReadTheDocs](https://py-libp2p.readthedocs.io/). [View the release notes](https://py-libp2p.readthedocs.io/en/latest/release_notes.html). Read more in the [documentation on ReadTheDocs](https://py-libp2p.readthedocs.io/). [View the release notes](https://py-libp2p.readthedocs.io/en/latest/release_notes.html).
## Maintainers ## Maintainers
Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby), looking for assistance! Currently maintained by [@pacrob](https://github.com/pacrob), [@seetadev](https://github.com/seetadev) and [@dhuseby](https://github.com/dhuseby). Please reach out to us for collaboration or active feedback. If you have questions, feel free to open a new [discussion](https://github.com/libp2p/py-libp2p/discussions). We are also available on the libp2p Discord — join us at #py-libp2p [sub-channel](https://discord.gg/d92MEugb).
## Feature Breakdown ## Feature Breakdown

View File

@ -8,6 +8,7 @@ from collections import (
import logging import logging
import time import time
import multihash
import trio import trio
from libp2p.abc import ( from libp2p.abc import (
@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
def peer_id_to_key(peer_id: ID) -> bytes:
"""
Convert a peer ID to a 256-bit key for routing table operations.
This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
:param peer_id: The peer ID to convert
:return: 32-byte (256-bit) key for routing table operations
"""
return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
def key_to_int(key: bytes) -> int:
"""Convert a 256-bit key to an integer for range calculations."""
return int.from_bytes(key, byteorder="big")
class KBucket: class KBucket:
""" """
A k-bucket implementation for the Kademlia DHT. A k-bucket implementation for the Kademlia DHT.
@ -357,9 +374,24 @@ class KBucket:
True if the key is in range, False otherwise True if the key is in range, False otherwise
""" """
key_int = int.from_bytes(key, byteorder="big") key_int = key_to_int(key)
return self.min_range <= key_int < self.max_range return self.min_range <= key_int < self.max_range
def peer_id_in_range(self, peer_id: ID) -> bool:
"""
Check if a peer ID is in the range of this bucket.
params: peer_id: The peer ID to check
Returns
-------
bool
True if the peer ID is in range, False otherwise
"""
key = peer_id_to_key(peer_id)
return self.key_in_range(key)
def split(self) -> tuple["KBucket", "KBucket"]: def split(self) -> tuple["KBucket", "KBucket"]:
""" """
Split the bucket into two buckets. Split the bucket into two buckets.
@ -376,8 +408,9 @@ class KBucket:
# Redistribute peers # Redistribute peers
for peer_id, (peer_info, timestamp) in self.peers.items(): for peer_id, (peer_info, timestamp) in self.peers.items():
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big") peer_key = peer_id_to_key(peer_id)
if peer_key < midpoint: peer_key_int = key_to_int(peer_key)
if peer_key_int < midpoint:
lower_bucket.peers[peer_id] = (peer_info, timestamp) lower_bucket.peers[peer_id] = (peer_info, timestamp)
else: else:
upper_bucket.peers[peer_id] = (peer_info, timestamp) upper_bucket.peers[peer_id] = (peer_info, timestamp)
@ -458,7 +491,38 @@ class RoutingTable:
success = await bucket.add_peer(peer_info) success = await bucket.add_peer(peer_info)
if success: if success:
logger.debug(f"Successfully added peer {peer_id} to routing table") logger.debug(f"Successfully added peer {peer_id} to routing table")
return success return True
# If bucket is full and couldn't add peer, try splitting the bucket
# Only split if the bucket contains our Peer ID
if self._should_split_bucket(bucket):
logger.debug(
f"Bucket is full, attempting to split bucket for peer {peer_id}"
)
split_success = self._split_bucket(bucket)
if split_success:
# After splitting,
# find the appropriate bucket for the peer and try to add it
target_bucket = self.find_bucket(peer_info.peer_id)
success = await target_bucket.add_peer(peer_info)
if success:
logger.debug(
f"Successfully added peer {peer_id} after bucket split"
)
return True
else:
logger.debug(
f"Failed to add peer {peer_id} even after bucket split"
)
return False
else:
logger.debug(f"Failed to split bucket for peer {peer_id}")
return False
else:
logger.debug(
f"Bucket is full and cannot be split, peer {peer_id} not added"
)
return False
except Exception as e: except Exception as e:
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}") logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
@ -480,9 +544,9 @@ class RoutingTable:
def find_bucket(self, peer_id: ID) -> KBucket: def find_bucket(self, peer_id: ID) -> KBucket:
""" """
Find the bucket that would contain the given peer ID or PeerInfo. Find the bucket that would contain the given peer ID.
:param peer_obj: Either a peer ID or a PeerInfo object :param peer_id: The peer ID to find a bucket for
Returns Returns
------- -------
@ -490,7 +554,7 @@ class RoutingTable:
""" """
for bucket in self.buckets: for bucket in self.buckets:
if bucket.key_in_range(peer_id.to_bytes()): if bucket.peer_id_in_range(peer_id):
return bucket return bucket
return self.buckets[0] return self.buckets[0]
@ -513,7 +577,11 @@ class RoutingTable:
all_peers.extend(bucket.peer_ids()) all_peers.extend(bucket.peer_ids())
# Sort by XOR distance to the key # Sort by XOR distance to the key
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key)) def distance_to_key(peer_id: ID) -> int:
peer_key = peer_id_to_key(peer_id)
return xor_distance(peer_key, key)
all_peers.sort(key=distance_to_key)
return all_peers[:count] return all_peers[:count]
@ -591,6 +659,20 @@ class RoutingTable:
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds)) stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
return stale_peers return stale_peers
def get_peer_infos(self) -> list[PeerInfo]:
"""
Get all PeerInfo objects in the routing table.
Returns
-------
List[PeerInfo]: List of all PeerInfo objects
"""
peer_infos = []
for bucket in self.buckets:
peer_infos.extend(bucket.peer_infos())
return peer_infos
def cleanup_routing_table(self) -> None: def cleanup_routing_table(self) -> None:
""" """
Cleanup the routing table by removing all data. Cleanup the routing table by removing all data.
@ -598,3 +680,66 @@ class RoutingTable:
""" """
self.buckets = [KBucket(self.host, BUCKET_SIZE)] self.buckets = [KBucket(self.host, BUCKET_SIZE)]
logger.info("Routing table cleaned up, all data removed.") logger.info("Routing table cleaned up, all data removed.")
def _should_split_bucket(self, bucket: KBucket) -> bool:
"""
Check if a bucket should be split according to Kademlia rules.
:param bucket: The bucket to check
:return: True if the bucket should be split
"""
# Check if we've exceeded maximum buckets
if len(self.buckets) >= MAXIMUM_BUCKETS:
logger.debug("Maximum number of buckets reached, cannot split")
return False
# Check if the bucket contains our local ID
local_key = peer_id_to_key(self.local_id)
local_key_int = key_to_int(local_key)
contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
logger.debug(
f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
)
return contains_local_id
def _split_bucket(self, bucket: KBucket) -> bool:
"""
Split a bucket into two buckets.
:param bucket: The bucket to split
:return: True if the bucket was successfully split
"""
try:
# Find the bucket index
bucket_index = self.buckets.index(bucket)
logger.debug(f"Splitting bucket at index {bucket_index}")
# Split the bucket
lower_bucket, upper_bucket = bucket.split()
# Replace the original bucket with the two new buckets
self.buckets[bucket_index] = lower_bucket
self.buckets.insert(bucket_index + 1, upper_bucket)
logger.debug(
f"Bucket split successful. New bucket count: {len(self.buckets)}"
)
logger.debug(
f"Lower bucket range: "
f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
f"peers: {lower_bucket.size()}"
)
logger.debug(
f"Upper bucket range: "
f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
f"peers: {upper_bucket.size()}"
)
return True
except Exception as e:
logger.error(f"Error splitting bucket: {e}")
return False

View File

@ -23,7 +23,8 @@ if TYPE_CHECKING:
""" """
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go Reference: https://github.com/libp2p/go-libp2p-swarm/blob/
04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go
""" """
@ -43,6 +44,21 @@ class SwarmConn(INetConn):
self.streams = set() self.streams = set()
self.event_closed = trio.Event() self.event_closed = trio.Event()
self.event_started = trio.Event() self.event_started = trio.Event()
# Provide back-references/hooks expected by NetStream
try:
setattr(self.muxed_conn, "swarm", self.swarm)
# NetStream expects an awaitable remove_stream hook
async def _remove_stream_hook(stream: NetStream) -> None:
self.remove_stream(stream)
setattr(self.muxed_conn, "remove_stream", _remove_stream_hook)
except Exception as e:
logging.warning(
f"Failed to set optional conveniences on muxed_conn "
f"for peer {muxed_conn.peer_id}: {e}"
)
# optional conveniences
if hasattr(muxed_conn, "on_close"): if hasattr(muxed_conn, "on_close"):
logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}") logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}")
setattr(muxed_conn, "on_close", self._on_muxed_conn_closed) setattr(muxed_conn, "on_close", self._on_muxed_conn_closed)

View File

@ -1,3 +1,7 @@
from collections.abc import (
Awaitable,
Callable,
)
import logging import logging
from multiaddr import ( from multiaddr import (
@ -333,8 +337,16 @@ class Swarm(Service, INetworkService):
# Close all listeners # Close all listeners
if hasattr(self, "listeners"): if hasattr(self, "listeners"):
for listener in self.listeners.values(): for maddr_str, listener in self.listeners.items():
await listener.close() await listener.close()
# Notify about listener closure
try:
multiaddr = Multiaddr(maddr_str)
await self.notify_listen_close(multiaddr)
except Exception as e:
logger.warning(
f"Failed to notify listen_close for {maddr_str}: {e}"
)
self.listeners.clear() self.listeners.clear()
# Close the transport if it exists and has a close method # Close the transport if it exists and has a close method
@ -418,7 +430,17 @@ class Swarm(Service, INetworkService):
nursery.start_soon(notifee.listen, self, multiaddr) nursery.start_soon(notifee.listen, self, multiaddr)
async def notify_closed_stream(self, stream: INetStream) -> None: async def notify_closed_stream(self, stream: INetStream) -> None:
raise NotImplementedError async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.closed_stream, self, stream)
async def notify_listen_close(self, multiaddr: Multiaddr) -> None: async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
raise NotImplementedError async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen_close, self, multiaddr)
# Generic notifier used by NetStream._notify_closed
async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifier, notifee)

View File

@ -0,0 +1,6 @@
Implement closed_stream notification in MyNotifee
- Add notify_closed_stream method to swarm notification system for proper stream lifecycle management
- Integrate remove_stream hook in SwarmConn to enable stream closure notifications
- Add comprehensive tests for closed_stream functionality in test_notify.py
- Enable stream lifecycle integration for proper cleanup and resource management

View File

@ -0,0 +1 @@
Fix kbucket splitting in routing table when full. Routing table now maintains multiple kbuckets and properly distributes peers as specified by the Kademlia DHT protocol.

View File

@ -10,8 +10,10 @@ readme = "README.md"
requires-python = ">=3.10, <4.0" requires-python = ">=3.10, <4.0"
license = { text = "MIT AND Apache-2.0" } license = { text = "MIT AND Apache-2.0" }
keywords = ["libp2p", "p2p"] keywords = ["libp2p", "p2p"]
authors = [ maintainers = [
{ name = "The Ethereum Foundation", email = "snakecharmers@ethereum.org" }, { name = "pacrob", email = "pacrob@protonmail.com" },
{ name = "Manu Sheel Gupta", email = "manu@seeta.in" },
{ name = "Dave Grantham", email = "dave@aviation.community" },
] ]
dependencies = [ dependencies = [
"base58>=1.0.3", "base58>=1.0.3",

View File

@ -226,6 +226,32 @@ class TestKBucket:
class TestRoutingTable: class TestRoutingTable:
"""Test suite for RoutingTable class.""" """Test suite for RoutingTable class."""
@pytest.mark.trio
async def test_kbucket_split_behavior(self, mock_host, local_peer_id):
"""
Test that adding more than BUCKET_SIZE peers to the routing table
triggers kbucket splitting and all peers are added.
"""
routing_table = RoutingTable(local_peer_id, mock_host)
num_peers = BUCKET_SIZE + 5
peer_ids = []
for i in range(num_peers):
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
peer_info = PeerInfo(peer_id, [Multiaddr(f"/ip4/127.0.0.1/tcp/{9000 + i}")])
peer_ids.append(peer_id)
added = await routing_table.add_peer(peer_info)
assert added, f"Peer {peer_id} should be added"
assert len(routing_table.buckets) > 1, "KBucket splitting did not occur"
for pid in peer_ids:
assert routing_table.peer_in_table(pid), f"Peer {pid} not found after split"
all_peer_ids = routing_table.get_peer_ids()
assert set(peer_ids).issubset(set(all_peer_ids)), (
"Not all peers present after split"
)
@pytest.fixture @pytest.fixture
def mock_host(self): def mock_host(self):
"""Create a mock host for testing.""" """Create a mock host for testing."""

View File

@ -5,11 +5,12 @@ the stream passed into opened_stream is correct.
Note: Listen event does not get hit because MyNotifee is passed Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those Note: ClosedStream events are processed asynchronously and may not be
features are implemented in swarm immediately available due to the rapid nature of operations
""" """
import enum import enum
from unittest.mock import Mock
import pytest import pytest
from multiaddr import Multiaddr from multiaddr import Multiaddr
@ -29,11 +30,11 @@ from tests.utils.factories import (
class Event(enum.Enum): class Event(enum.Enum):
OpenedStream = 0 OpenedStream = 0
ClosedStream = 1 # Not implemented ClosedStream = 1
Connected = 2 Connected = 2
Disconnected = 3 Disconnected = 3
Listen = 4 Listen = 4
ListenClose = 5 # Not implemented ListenClose = 5
class MyNotifee(INotifee): class MyNotifee(INotifee):
@ -44,8 +45,11 @@ class MyNotifee(INotifee):
self.events.append(Event.OpenedStream) self.events.append(Event.OpenedStream)
async def closed_stream(self, network: INetwork, stream: INetStream) -> None: async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
# TODO: It is not implemented yet. if network is None:
pass raise ValueError("network parameter cannot be None")
if stream is None:
raise ValueError("stream parameter cannot be None")
self.events.append(Event.ClosedStream)
async def connected(self, network: INetwork, conn: INetConn) -> None: async def connected(self, network: INetwork, conn: INetConn) -> None:
self.events.append(Event.Connected) self.events.append(Event.Connected)
@ -57,8 +61,11 @@ class MyNotifee(INotifee):
self.events.append(Event.Listen) self.events.append(Event.Listen)
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
# TODO: It is not implemented yet. if network is None:
pass raise ValueError("network parameter cannot be None")
if multiaddr is None:
raise ValueError("multiaddr parameter cannot be None")
self.events.append(Event.ListenClose)
@pytest.mark.trio @pytest.mark.trio
@ -103,28 +110,188 @@ async def test_notify(security_protocol):
# Wait for events # Wait for events
assert await wait_for_event(events_0_0, Event.Connected, 1.0) assert await wait_for_event(events_0_0, Event.Connected, 1.0)
assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0) assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0)
# assert await wait_for_event( assert await wait_for_event(events_0_0, Event.ClosedStream, 1.0)
# events_0_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_0, Event.Disconnected, 1.0) assert await wait_for_event(events_0_0, Event.Disconnected, 1.0)
assert await wait_for_event(events_0_1, Event.Connected, 1.0) assert await wait_for_event(events_0_1, Event.Connected, 1.0)
assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0) assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0)
# assert await wait_for_event( assert await wait_for_event(events_0_1, Event.ClosedStream, 1.0)
# events_0_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_0_1, Event.Disconnected, 1.0) assert await wait_for_event(events_0_1, Event.Disconnected, 1.0)
assert await wait_for_event(events_1_0, Event.Connected, 1.0) assert await wait_for_event(events_1_0, Event.Connected, 1.0)
assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0) assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0)
# assert await wait_for_event( assert await wait_for_event(events_1_0, Event.ClosedStream, 1.0)
# events_1_0, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_0, Event.Disconnected, 1.0) assert await wait_for_event(events_1_0, Event.Disconnected, 1.0)
assert await wait_for_event(events_1_1, Event.Connected, 1.0) assert await wait_for_event(events_1_1, Event.Connected, 1.0)
assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0) assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
# assert await wait_for_event( assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0)
# events_1_1, Event.ClosedStream, 1.0
# ) # Not implemented
assert await wait_for_event(events_1_1, Event.Disconnected, 1.0) assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)
# Note: ListenClose events are triggered when swarm closes during cleanup
# The test framework automatically closes listeners, triggering ListenClose
# notifications
async def wait_for_event(events_list, event, timeout=1.0):
"""Helper to wait for a specific event to appear in the events list."""
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
@pytest.mark.trio
async def test_notify_with_closed_stream_and_listen_close():
"""Test that closed_stream and listen_close events are properly triggered."""
# Event lists for notifees
events_0 = []
events_1 = []
# Create two swarms
async with SwarmFactory.create_batch_and_listen(2) as swarms:
# Register notifees
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Create and close a stream to trigger closed_stream event
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Note: Events are processed asynchronously and may not be immediately available
# due to the rapid nature of operations
@pytest.mark.trio
async def test_notify_edge_cases():
"""Test edge cases for notify system."""
events = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee = MyNotifee(events)
swarms[0].register_notifee(notifee)
# Connect swarms first
await connect_swarm(swarms[0], swarms[1])
# Test 1: Multiple rapid stream operations
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams rapidly
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_my_notifee_error_handling():
"""Test error handling for invalid parameters in MyNotifee methods."""
events = []
notifee = MyNotifee(events)
# Mock objects for testing
mock_network = Mock(spec=INetwork)
mock_stream = Mock(spec=INetStream)
mock_multiaddr = Mock(spec=Multiaddr)
# Test closed_stream with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.closed_stream(None, mock_stream) # type: ignore
with pytest.raises(ValueError, match="stream parameter cannot be None"):
await notifee.closed_stream(mock_network, None) # type: ignore
# Test listen_close with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.listen_close(None, mock_multiaddr) # type: ignore
with pytest.raises(ValueError, match="multiaddr parameter cannot be None"):
await notifee.listen_close(mock_network, None) # type: ignore
# Verify no events were recorded due to errors
assert len(events) == 0
@pytest.mark.trio
async def test_rapid_stream_operations():
"""Test rapid stream open/close operations."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Rapidly create and close multiple streams
streams = []
for _ in range(3):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams immediately
for stream in streams:
await stream.close()
# Verify OpenedStream events are recorded
assert events_0.count(Event.OpenedStream) == 3
assert events_1.count(Event.OpenedStream) == 3
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())
@pytest.mark.trio
async def test_concurrent_stream_operations():
"""Test concurrent stream operations using trio nursery."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
async def create_and_close_stream():
"""Create and immediately close a stream."""
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Run multiple stream operations concurrently
async with trio.open_nursery() as nursery:
for _ in range(4):
nursery.start_soon(create_and_close_stream)
# Verify some OpenedStream events are recorded
# (concurrent operations may not all succeed)
opened_count_0 = events_0.count(Event.OpenedStream)
opened_count_1 = events_1.count(Event.OpenedStream)
assert opened_count_0 > 0, (
f"Expected some OpenedStream events, got {opened_count_0}"
)
assert opened_count_1 > 0, (
f"Expected some OpenedStream events, got {opened_count_1}"
)
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())