Mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2026-02-12 16:10:57 +00:00)

Commit: Merge branch 'main' into fix_pubsub_msg_id_type_inconsistency
@@ -8,6 +8,7 @@ from collections import (
 import logging
 import time
 
+import multihash
 import trio
 
 from libp2p.abc import (
@@ -40,6 +41,22 @@ PEER_REFRESH_INTERVAL = 60  # Interval to refresh peers in seconds
 STALE_PEER_THRESHOLD = 3600  # Time in seconds after which a peer is considered stale
 
 
+def peer_id_to_key(peer_id: ID) -> bytes:
+    """
+    Convert a peer ID to a 256-bit key for routing table operations.
+
+    This normalizes all peer IDs to exactly 256 bits by hashing them with SHA-256.
+
+    :param peer_id: The peer ID to convert
+    :return: 32-byte (256-bit) key for routing table operations
+    """
+    return multihash.digest(peer_id.to_bytes(), "sha2-256").digest
+
+
+def key_to_int(key: bytes) -> int:
+    """Convert a 256-bit key to an integer for range calculations."""
+    return int.from_bytes(key, byteorder="big")
+
+
 class KBucket:
     """
     A k-bucket implementation for the Kademlia DHT.
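The two helpers above normalize variable-length peer IDs into comparable 256-bit keys. A minimal standalone sketch of the same idea, assuming the multihash package's digest(...).digest returns the raw 32-byte SHA-256 digest (the identifier bytes below are illustrative, not real peer IDs):

import multihash


def to_key(raw_id: bytes) -> bytes:
    # SHA-256 yields 32 bytes regardless of the input length.
    return multihash.digest(raw_id, "sha2-256").digest


def key_int(key: bytes) -> int:
    # Big-endian integer value, usable for bucket range comparisons.
    return int.from_bytes(key, byteorder="big")


short_id = b"\x01" * 34   # illustrative identifier bytes
long_id = b"\x02" * 38    # a differently sized identifier

k1, k2 = to_key(short_id), to_key(long_id)
print(len(k1), len(k2))           # both 32 bytes after normalization
print(key_int(k1) < key_int(k2))  # keys are totally ordered as integers
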
@@ -357,9 +374,24 @@ class KBucket:
             True if the key is in range, False otherwise
 
         """
-        key_int = int.from_bytes(key, byteorder="big")
+        key_int = key_to_int(key)
         return self.min_range <= key_int < self.max_range
 
+    def peer_id_in_range(self, peer_id: ID) -> bool:
+        """
+        Check if a peer ID is in the range of this bucket.
+
+        :param peer_id: The peer ID to check
+
+        Returns
+        -------
+        bool
+            True if the peer ID is in range, False otherwise
+
+        """
+        key = peer_id_to_key(peer_id)
+        return self.key_in_range(key)
+
     def split(self) -> tuple["KBucket", "KBucket"]:
         """
         Split the bucket into two buckets.
@ -376,8 +408,9 @@ class KBucket:
|
|||||||
|
|
||||||
# Redistribute peers
|
# Redistribute peers
|
||||||
for peer_id, (peer_info, timestamp) in self.peers.items():
|
for peer_id, (peer_info, timestamp) in self.peers.items():
|
||||||
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
|
peer_key = peer_id_to_key(peer_id)
|
||||||
if peer_key < midpoint:
|
peer_key_int = key_to_int(peer_key)
|
||||||
|
if peer_key_int < midpoint:
|
||||||
lower_bucket.peers[peer_id] = (peer_info, timestamp)
|
lower_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||||
else:
|
else:
|
||||||
upper_bucket.peers[peer_id] = (peer_info, timestamp)
|
upper_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||||
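As a toy illustration of the redistribution above (plain integers, not the library's KBucket API): splitting halves the bucket's key range at its midpoint, and each entry's key decides which half it moves to.

def split_range(min_range: int, max_range: int, keys: list[int]):
    # Halve the range and place each key by comparing against the midpoint,
    # mirroring the peer_key_int < midpoint check in KBucket.split().
    midpoint = (min_range + max_range) // 2
    lower = [k for k in keys if k < midpoint]
    upper = [k for k in keys if k >= midpoint]
    return (min_range, midpoint, lower), (midpoint, max_range, upper)


# Tiny 8-bit key space for readability; the real table uses 256-bit keys.
lower_half, upper_half = split_range(0, 256, [3, 130, 200, 7])
print(lower_half)  # (0, 128, [3, 7])
print(upper_half)  # (128, 256, [130, 200])
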
@@ -458,7 +491,38 @@ class RoutingTable:
             success = await bucket.add_peer(peer_info)
             if success:
                 logger.debug(f"Successfully added peer {peer_id} to routing table")
-            return success
+                return True
+
+            # If bucket is full and couldn't add peer, try splitting the bucket
+            # Only split if the bucket contains our Peer ID
+            if self._should_split_bucket(bucket):
+                logger.debug(
+                    f"Bucket is full, attempting to split bucket for peer {peer_id}"
+                )
+                split_success = self._split_bucket(bucket)
+                if split_success:
+                    # After splitting,
+                    # find the appropriate bucket for the peer and try to add it
+                    target_bucket = self.find_bucket(peer_info.peer_id)
+                    success = await target_bucket.add_peer(peer_info)
+                    if success:
+                        logger.debug(
+                            f"Successfully added peer {peer_id} after bucket split"
+                        )
+                        return True
+                    else:
+                        logger.debug(
+                            f"Failed to add peer {peer_id} even after bucket split"
+                        )
+                        return False
+                else:
+                    logger.debug(f"Failed to split bucket for peer {peer_id}")
+                    return False
+            else:
+                logger.debug(
+                    f"Bucket is full and cannot be split, peer {peer_id} not added"
+                )
+                return False
+
         except Exception as e:
             logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
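Condensed, the insertion path added above amounts to the following sketch (insert is a hypothetical wrapper; find_bucket, add_peer, _should_split_bucket, and _split_bucket are the names from this diff): try the target bucket first, and only split and retry when the full bucket is one the routing table is allowed to split.

async def insert(table, peer_info) -> bool:
    # Sketch only: the real method also logs and wraps this in try/except.
    bucket = table.find_bucket(peer_info.peer_id)
    if await bucket.add_peer(peer_info):
        return True                                # there was room
    if not table._should_split_bucket(bucket):
        return False                               # full and not splittable
    if not table._split_bucket(bucket):
        return False                               # split failed
    # The peer may now belong to either new bucket, so resolve it again.
    return await table.find_bucket(peer_info.peer_id).add_peer(peer_info)
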
@@ -480,9 +544,9 @@ class RoutingTable:
 
     def find_bucket(self, peer_id: ID) -> KBucket:
         """
-        Find the bucket that would contain the given peer ID or PeerInfo.
+        Find the bucket that would contain the given peer ID.
 
-        :param peer_obj: Either a peer ID or a PeerInfo object
+        :param peer_id: The peer ID to find a bucket for
 
         Returns
         -------
@@ -490,7 +554,7 @@ class RoutingTable:
 
         """
         for bucket in self.buckets:
-            if bucket.key_in_range(peer_id.to_bytes()):
+            if bucket.peer_id_in_range(peer_id):
                 return bucket
 
         return self.buckets[0]
@@ -513,7 +577,11 @@ class RoutingTable:
             all_peers.extend(bucket.peer_ids())
 
         # Sort by XOR distance to the key
-        all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
+        def distance_to_key(peer_id: ID) -> int:
+            peer_key = peer_id_to_key(peer_id)
+            return xor_distance(peer_key, key)
+
+        all_peers.sort(key=distance_to_key)
 
         return all_peers[:count]
 
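For context, the sort key above is the Kademlia XOR metric computed over the normalized 256-bit keys. A self-contained stand-in (assuming the module's xor_distance behaves like the standard metric):

def xor_distance_sketch(key_a: bytes, key_b: bytes) -> int:
    # Standard Kademlia distance: bitwise XOR of the two keys as integers.
    return int.from_bytes(key_a, "big") ^ int.from_bytes(key_b, "big")


target = bytes(32)             # all-zero target key
near = bytes(31) + b"\x01"     # differs from target in the lowest bit
far = b"\x80" + bytes(31)      # differs from target in the highest bit

candidates = [far, near]
candidates.sort(key=lambda k: xor_distance_sketch(k, target))
print(candidates[0] == near)   # True: smaller XOR distance sorts first
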
@@ -591,6 +659,20 @@ class RoutingTable:
             stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
         return stale_peers
 
+    def get_peer_infos(self) -> list[PeerInfo]:
+        """
+        Get all PeerInfo objects in the routing table.
+
+        Returns
+        -------
+        List[PeerInfo]: List of all PeerInfo objects
+
+        """
+        peer_infos = []
+        for bucket in self.buckets:
+            peer_infos.extend(bucket.peer_infos())
+        return peer_infos
+
     def cleanup_routing_table(self) -> None:
         """
         Cleanup the routing table by removing all data.
@@ -598,3 +680,66 @@ class RoutingTable:
         """
         self.buckets = [KBucket(self.host, BUCKET_SIZE)]
         logger.info("Routing table cleaned up, all data removed.")
+
+    def _should_split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Check if a bucket should be split according to Kademlia rules.
+
+        :param bucket: The bucket to check
+        :return: True if the bucket should be split
+        """
+        # Check if we've exceeded maximum buckets
+        if len(self.buckets) >= MAXIMUM_BUCKETS:
+            logger.debug("Maximum number of buckets reached, cannot split")
+            return False
+
+        # Check if the bucket contains our local ID
+        local_key = peer_id_to_key(self.local_id)
+        local_key_int = key_to_int(local_key)
+        contains_local_id = bucket.min_range <= local_key_int < bucket.max_range
+
+        logger.debug(
+            f"Bucket range: {bucket.min_range} - {bucket.max_range}, "
+            f"local_key_int: {local_key_int}, contains_local: {contains_local_id}"
+        )
+
+        return contains_local_id
+
+    def _split_bucket(self, bucket: KBucket) -> bool:
+        """
+        Split a bucket into two buckets.
+
+        :param bucket: The bucket to split
+        :return: True if the bucket was successfully split
+        """
+        try:
+            # Find the bucket index
+            bucket_index = self.buckets.index(bucket)
+            logger.debug(f"Splitting bucket at index {bucket_index}")
+
+            # Split the bucket
+            lower_bucket, upper_bucket = bucket.split()
+
+            # Replace the original bucket with the two new buckets
+            self.buckets[bucket_index] = lower_bucket
+            self.buckets.insert(bucket_index + 1, upper_bucket)
+
+            logger.debug(
+                f"Bucket split successful. New bucket count: {len(self.buckets)}"
+            )
+            logger.debug(
+                f"Lower bucket range: "
+                f"{lower_bucket.min_range} - {lower_bucket.max_range}, "
+                f"peers: {lower_bucket.size()}"
+            )
+            logger.debug(
+                f"Upper bucket range: "
+                f"{upper_bucket.min_range} - {upper_bucket.max_range}, "
+                f"peers: {upper_bucket.size()}"
+            )
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Error splitting bucket: {e}")
+            return False
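A small illustration of why _should_split_bucket only allows splitting the bucket that covers the local peer's key (toy integer key space, not the library API): repeated splits follow the local key, so the bucket count grows with the key width rather than with the number of peers.

def split_toward(local_key: int, key_bits: int = 8) -> list[tuple[int, int]]:
    # Start with one bucket covering the whole key space and repeatedly split
    # only the range containing local_key, as the routing table does.
    buckets = [(0, 2**key_bits)]
    for _ in range(key_bits):
        for i, (lo, hi) in enumerate(buckets):
            if lo <= local_key < hi and hi - lo > 1:
                mid = (lo + hi) // 2
                buckets[i:i + 1] = [(lo, mid), (mid, hi)]
                break
    return buckets


print(len(split_toward(local_key=5)))  # 9 buckets for an 8-bit space, not 256
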
@@ -330,8 +330,16 @@ class Swarm(Service, INetworkService):
 
         # Close all listeners
         if hasattr(self, "listeners"):
-            for listener in self.listeners.values():
+            for maddr_str, listener in self.listeners.items():
                 await listener.close()
+                # Notify about listener closure
+                try:
+                    multiaddr = Multiaddr(maddr_str)
+                    await self.notify_listen_close(multiaddr)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to notify listen_close for {maddr_str}: {e}"
+                    )
             self.listeners.clear()
 
         # Close the transport if it exists and has a close method
@@ -420,7 +428,9 @@ class Swarm(Service, INetworkService):
                 nursery.start_soon(notifee.closed_stream, self, stream)
 
     async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
-        raise NotImplementedError
+        async with trio.open_nursery() as nursery:
+            for notifee in self.notifees:
+                nursery.start_soon(notifee.listen_close, self, multiaddr)
 
     # Generic notifier used by NetStream._notify_closed
     async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None:
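The implementation above follows the same fan-out pattern as the other notify_* methods: every registered notifee's callback is started in a trio nursery so that one slow listener does not block the rest. A self-contained sketch of that pattern with a toy notifee (not the libp2p INotifee interface):

import trio


class RecordingNotifee:
    """Toy notifee that records which listen addresses were closed."""

    def __init__(self) -> None:
        self.closed: list[str] = []

    async def listen_close(self, network: object, addr: str) -> None:
        await trio.sleep(0.01)  # pretend to do some work
        self.closed.append(addr)


async def broadcast_listen_close(notifees: list[RecordingNotifee], addr: str) -> None:
    # Start every callback concurrently; the nursery waits for all of them.
    async with trio.open_nursery() as nursery:
        for notifee in notifees:
            nursery.start_soon(notifee.listen_close, None, addr)


async def main() -> None:
    notifees = [RecordingNotifee(), RecordingNotifee()]
    await broadcast_listen_close(notifees, "/ip4/127.0.0.1/tcp/8000")
    print([n.closed for n in notifees])


trio.run(main)
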
newsfragments/846.bugfix.rst (new file)
@@ -0,0 +1 @@
+Fix kbucket splitting in routing table when full. Routing table now maintains multiple kbuckets and properly distributes peers as specified by the Kademlia DHT protocol.
@@ -226,6 +226,32 @@ class TestKBucket:
 class TestRoutingTable:
     """Test suite for RoutingTable class."""
 
+    @pytest.mark.trio
+    async def test_kbucket_split_behavior(self, mock_host, local_peer_id):
+        """
+        Test that adding more than BUCKET_SIZE peers to the routing table
+        triggers kbucket splitting and all peers are added.
+        """
+        routing_table = RoutingTable(local_peer_id, mock_host)
+
+        num_peers = BUCKET_SIZE + 5
+        peer_ids = []
+        for i in range(num_peers):
+            key_pair = create_new_key_pair()
+            peer_id = ID.from_pubkey(key_pair.public_key)
+            peer_info = PeerInfo(peer_id, [Multiaddr(f"/ip4/127.0.0.1/tcp/{9000 + i}")])
+            peer_ids.append(peer_id)
+            added = await routing_table.add_peer(peer_info)
+            assert added, f"Peer {peer_id} should be added"
+
+        assert len(routing_table.buckets) > 1, "KBucket splitting did not occur"
+        for pid in peer_ids:
+            assert routing_table.peer_in_table(pid), f"Peer {pid} not found after split"
+        all_peer_ids = routing_table.get_peer_ids()
+        assert set(peer_ids).issubset(set(all_peer_ids)), (
+            "Not all peers present after split"
+        )
+
     @pytest.fixture
     def mock_host(self):
         """Create a mock host for testing."""
@@ -5,11 +5,12 @@ the stream passed into opened_stream is correct.
 Note: Listen event does not get hit because MyNotifee is passed
 into network after network has already started listening
 
-TODO: Add tests for closed_stream, listen_close when those
-features are implemented in swarm
+Note: ClosedStream events are processed asynchronously and may not be
+immediately available due to the rapid nature of operations
 """
 
 import enum
+from unittest.mock import Mock
 
 import pytest
 from multiaddr import Multiaddr
@@ -29,11 +30,11 @@ from tests.utils.factories import (
 
 class Event(enum.Enum):
     OpenedStream = 0
-    ClosedStream = 1  # Not implemented
+    ClosedStream = 1
     Connected = 2
     Disconnected = 3
     Listen = 4
-    ListenClose = 5  # Not implemented
+    ListenClose = 5
 
 
 class MyNotifee(INotifee):
@@ -60,8 +61,11 @@ class MyNotifee(INotifee):
         self.events.append(Event.Listen)
 
     async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
-        # TODO: It is not implemented yet.
-        pass
+        if network is None:
+            raise ValueError("network parameter cannot be None")
+        if multiaddr is None:
+            raise ValueError("multiaddr parameter cannot be None")
+        self.events.append(Event.ListenClose)
 
 
 @pytest.mark.trio
@@ -123,3 +127,171 @@ async def test_notify(security_protocol):
     assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
     assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0)
     assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)
+
+    # Note: ListenClose events are triggered when swarm closes during cleanup
+    # The test framework automatically closes listeners, triggering ListenClose
+    # notifications
+
+
+async def wait_for_event(events_list, event, timeout=1.0):
+    """Helper to wait for a specific event to appear in the events list."""
+    with trio.move_on_after(timeout):
+        while event not in events_list:
+            await trio.sleep(0.01)
+        return True
+    return False
+
+
+@pytest.mark.trio
+async def test_notify_with_closed_stream_and_listen_close():
+    """Test that closed_stream and listen_close events are properly triggered."""
+    # Event lists for notifees
+    events_0 = []
+    events_1 = []
+
+    # Create two swarms
+    async with SwarmFactory.create_batch_and_listen(2) as swarms:
+        # Register notifees
+        notifee_0 = MyNotifee(events_0)
+        notifee_1 = MyNotifee(events_1)
+
+        swarms[0].register_notifee(notifee_0)
+        swarms[1].register_notifee(notifee_1)
+
+        # Connect swarms
+        await connect_swarm(swarms[0], swarms[1])
+
+        # Create and close a stream to trigger closed_stream event
+        stream = await swarms[0].new_stream(swarms[1].get_peer_id())
+        await stream.close()
+
+        # Note: Events are processed asynchronously and may not be immediately available
+        # due to the rapid nature of operations
+
+
+@pytest.mark.trio
+async def test_notify_edge_cases():
+    """Test edge cases for notify system."""
+    events = []
+
+    async with SwarmFactory.create_batch_and_listen(2) as swarms:
+        notifee = MyNotifee(events)
+        swarms[0].register_notifee(notifee)
+
+        # Connect swarms first
+        await connect_swarm(swarms[0], swarms[1])
+
+        # Test 1: Multiple rapid stream operations
+        streams = []
+        for _ in range(5):
+            stream = await swarms[0].new_stream(swarms[1].get_peer_id())
+            streams.append(stream)
+
+        # Close all streams rapidly
+        for stream in streams:
+            await stream.close()
+
+
+@pytest.mark.trio
+async def test_my_notifee_error_handling():
+    """Test error handling for invalid parameters in MyNotifee methods."""
+    events = []
+    notifee = MyNotifee(events)
+
+    # Mock objects for testing
+    mock_network = Mock(spec=INetwork)
+    mock_stream = Mock(spec=INetStream)
+    mock_multiaddr = Mock(spec=Multiaddr)
+
+    # Test closed_stream with None parameters
+    with pytest.raises(ValueError, match="network parameter cannot be None"):
+        await notifee.closed_stream(None, mock_stream)  # type: ignore
+
+    with pytest.raises(ValueError, match="stream parameter cannot be None"):
+        await notifee.closed_stream(mock_network, None)  # type: ignore
+
+    # Test listen_close with None parameters
+    with pytest.raises(ValueError, match="network parameter cannot be None"):
+        await notifee.listen_close(None, mock_multiaddr)  # type: ignore
+
+    with pytest.raises(ValueError, match="multiaddr parameter cannot be None"):
+        await notifee.listen_close(mock_network, None)  # type: ignore
+
+    # Verify no events were recorded due to errors
+    assert len(events) == 0
+
+
+@pytest.mark.trio
+async def test_rapid_stream_operations():
+    """Test rapid stream open/close operations."""
+    events_0 = []
+    events_1 = []
+
+    async with SwarmFactory.create_batch_and_listen(2) as swarms:
+        notifee_0 = MyNotifee(events_0)
+        notifee_1 = MyNotifee(events_1)
+
+        swarms[0].register_notifee(notifee_0)
+        swarms[1].register_notifee(notifee_1)
+
+        # Connect swarms
+        await connect_swarm(swarms[0], swarms[1])
+
+        # Rapidly create and close multiple streams
+        streams = []
+        for _ in range(3):
+            stream = await swarms[0].new_stream(swarms[1].get_peer_id())
+            streams.append(stream)
+
+        # Close all streams immediately
+        for stream in streams:
+            await stream.close()
+
+        # Verify OpenedStream events are recorded
+        assert events_0.count(Event.OpenedStream) == 3
+        assert events_1.count(Event.OpenedStream) == 3
+
+        # Close peer to trigger disconnection events
+        await swarms[0].close_peer(swarms[1].get_peer_id())
+
+
+@pytest.mark.trio
+async def test_concurrent_stream_operations():
+    """Test concurrent stream operations using trio nursery."""
+    events_0 = []
+    events_1 = []
+
+    async with SwarmFactory.create_batch_and_listen(2) as swarms:
+        notifee_0 = MyNotifee(events_0)
+        notifee_1 = MyNotifee(events_1)
+
+        swarms[0].register_notifee(notifee_0)
+        swarms[1].register_notifee(notifee_1)
+
+        # Connect swarms
+        await connect_swarm(swarms[0], swarms[1])
+
+        async def create_and_close_stream():
+            """Create and immediately close a stream."""
+            stream = await swarms[0].new_stream(swarms[1].get_peer_id())
+            await stream.close()
+
+        # Run multiple stream operations concurrently
+        async with trio.open_nursery() as nursery:
+            for _ in range(4):
+                nursery.start_soon(create_and_close_stream)
+
+        # Verify some OpenedStream events are recorded
+        # (concurrent operations may not all succeed)
+        opened_count_0 = events_0.count(Event.OpenedStream)
+        opened_count_1 = events_1.count(Event.OpenedStream)
+
+        assert opened_count_0 > 0, (
+            f"Expected some OpenedStream events, got {opened_count_0}"
+        )
+        assert opened_count_1 > 0, (
+            f"Expected some OpenedStream events, got {opened_count_1}"
+        )
+
+        # Close peer to trigger disconnection events
+        await swarms[0].close_peer(swarms[1].get_peer_id())