mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 07:30:55 +00:00
Merge branch 'main' into add-read-write-lock
This commit is contained in:
@ -72,13 +72,46 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
|||||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||||
|
|
||||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||||
|
format_flag = "--raw-format" if not use_varint_format else ""
|
||||||
print(
|
print(
|
||||||
f"First host listening (using {format_name} format). "
|
f"First host listening (using {format_name} format). "
|
||||||
f"Run this from another console:\n\n"
|
f"Run this from another console:\n\n"
|
||||||
f"identify-demo "
|
f"identify-demo {format_flag} -d {client_addr}\n"
|
||||||
f"-d {client_addr}\n"
|
|
||||||
)
|
)
|
||||||
print("Waiting for incoming identify request...")
|
print("Waiting for incoming identify request...")
|
||||||
|
|
||||||
|
# Add a custom handler to show connection events
|
||||||
|
async def custom_identify_handler(stream):
|
||||||
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
print(f"\n🔗 Received identify request from peer: {peer_id}")
|
||||||
|
|
||||||
|
# Show remote address in multiaddr format
|
||||||
|
try:
|
||||||
|
from libp2p.identity.identify.identify import (
|
||||||
|
_remote_address_to_multiaddr,
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_address = stream.get_remote_address()
|
||||||
|
if remote_address:
|
||||||
|
observed_multiaddr = _remote_address_to_multiaddr(
|
||||||
|
remote_address
|
||||||
|
)
|
||||||
|
# Add the peer ID to create a complete multiaddr
|
||||||
|
complete_multiaddr = f"{observed_multiaddr}/p2p/{peer_id}"
|
||||||
|
print(f" Remote address: {complete_multiaddr}")
|
||||||
|
else:
|
||||||
|
print(f" Remote address: {remote_address}")
|
||||||
|
except Exception:
|
||||||
|
print(f" Remote address: {stream.get_remote_address()}")
|
||||||
|
|
||||||
|
# Call the original handler
|
||||||
|
await identify_handler(stream)
|
||||||
|
|
||||||
|
print(f"✅ Successfully processed identify request from {peer_id}")
|
||||||
|
|
||||||
|
# Replace the handler with our custom one
|
||||||
|
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, custom_identify_handler)
|
||||||
|
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -93,25 +126,99 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
|||||||
info = info_from_p2p_addr(maddr)
|
info = info_from_p2p_addr(maddr)
|
||||||
print(f"Second host connecting to peer: {info.peer_id}")
|
print(f"Second host connecting to peer: {info.peer_id}")
|
||||||
|
|
||||||
await host_b.connect(info)
|
try:
|
||||||
|
await host_b.connect(info)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
||||||
|
print(f"\n❌ Cannot connect to peer: {info.peer_id}")
|
||||||
|
print(f" Address: {destination}")
|
||||||
|
print(f" Error: {error_msg}")
|
||||||
|
print(
|
||||||
|
"\n💡 Make sure the peer is running and the address is correct."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Re-raise other exceptions
|
||||||
|
raise
|
||||||
|
|
||||||
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
|
stream = await host_b.new_stream(info.peer_id, (IDENTIFY_PROTOCOL_ID,))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Starting identify protocol...")
|
print("Starting identify protocol...")
|
||||||
|
|
||||||
# Read the complete response (could be either format)
|
# Read the response properly based on the format
|
||||||
# Read a larger chunk to get all the data before stream closes
|
if use_varint_format:
|
||||||
response = await stream.read(8192) # Read enough data in one go
|
# For length-prefixed format, read varint length first
|
||||||
|
from libp2p.utils.varint import decode_varint_from_bytes
|
||||||
|
|
||||||
|
# Read varint length prefix
|
||||||
|
length_bytes = b""
|
||||||
|
while True:
|
||||||
|
b = await stream.read(1)
|
||||||
|
if not b:
|
||||||
|
raise Exception("Stream closed while reading varint length")
|
||||||
|
length_bytes += b
|
||||||
|
if b[0] & 0x80 == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
msg_length = decode_varint_from_bytes(length_bytes)
|
||||||
|
print(f"Expected message length: {msg_length} bytes")
|
||||||
|
|
||||||
|
# Read the protobuf message
|
||||||
|
response = await stream.read(msg_length)
|
||||||
|
if len(response) != msg_length:
|
||||||
|
raise Exception(
|
||||||
|
f"Incomplete message: expected {msg_length} bytes, "
|
||||||
|
f"got {len(response)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine length prefix and message
|
||||||
|
full_response = length_bytes + response
|
||||||
|
else:
|
||||||
|
# For raw format, read all available data
|
||||||
|
response = await stream.read(8192)
|
||||||
|
full_response = response
|
||||||
|
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
# Parse the response using the robust protocol-level function
|
# Parse the response using the robust protocol-level function
|
||||||
# This handles both old and new formats automatically
|
# This handles both old and new formats automatically
|
||||||
identify_msg = parse_identify_response(response)
|
identify_msg = parse_identify_response(full_response)
|
||||||
print_identify_response(identify_msg)
|
print_identify_response(identify_msg)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Identify protocol error: {e}")
|
error_msg = str(e)
|
||||||
|
print(f"Identify protocol error: {error_msg}")
|
||||||
|
|
||||||
|
# Check for specific format mismatch errors
|
||||||
|
if "Error parsing message" in error_msg or "DecodeError" in error_msg:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("FORMAT MISMATCH DETECTED!")
|
||||||
|
print("=" * 60)
|
||||||
|
if use_varint_format:
|
||||||
|
print(
|
||||||
|
"You are using length-prefixed format (default) but the "
|
||||||
|
"listener"
|
||||||
|
)
|
||||||
|
print("is using raw protobuf format.")
|
||||||
|
print(
|
||||||
|
"\nTo fix this, run the dialer with the --raw-format flag:"
|
||||||
|
)
|
||||||
|
print(f"identify-demo --raw-format -d {destination}")
|
||||||
|
else:
|
||||||
|
print("You are using raw protobuf format but the listener")
|
||||||
|
print("is using length-prefixed format (default).")
|
||||||
|
print(
|
||||||
|
"\nTo fix this, run the dialer without the --raw-format "
|
||||||
|
"flag:"
|
||||||
|
)
|
||||||
|
print(f"identify-demo -d {destination}")
|
||||||
|
print("=" * 60)
|
||||||
|
else:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import base58
|
import base58
|
||||||
@ -36,25 +37,23 @@ if ENABLE_INLINING:
|
|||||||
|
|
||||||
class ID:
|
class ID:
|
||||||
_bytes: bytes
|
_bytes: bytes
|
||||||
_xor_id: int | None = None
|
|
||||||
_b58_str: str | None = None
|
|
||||||
|
|
||||||
def __init__(self, peer_id_bytes: bytes) -> None:
|
def __init__(self, peer_id_bytes: bytes) -> None:
|
||||||
self._bytes = peer_id_bytes
|
self._bytes = peer_id_bytes
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
def xor_id(self) -> int:
|
def xor_id(self) -> int:
|
||||||
if not self._xor_id:
|
return int(sha256_digest(self._bytes).hex(), 16)
|
||||||
self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
|
|
||||||
return self._xor_id
|
@functools.cached_property
|
||||||
|
def base58(self) -> str:
|
||||||
|
return base58.b58encode(self._bytes).decode()
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
return self._bytes
|
return self._bytes
|
||||||
|
|
||||||
def to_base58(self) -> str:
|
def to_base58(self) -> str:
|
||||||
if not self._b58_str:
|
return self.base58
|
||||||
self._b58_str = base58.b58encode(self._bytes).decode()
|
|
||||||
return self._b58_str
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<libp2p.peer.id.ID ({self!s})>"
|
return f"<libp2p.peer.id.ID ({self!s})>"
|
||||||
|
|||||||
@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
|
|||||||
is_async: bool
|
is_async: bool
|
||||||
|
|
||||||
|
|
||||||
|
MAX_CONCURRENT_VALIDATORS = 10
|
||||||
|
|
||||||
|
|
||||||
class Pubsub(Service, IPubsub):
|
class Pubsub(Service, IPubsub):
|
||||||
host: IHost
|
host: IHost
|
||||||
|
|
||||||
@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
|
|
||||||
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||||
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||||
|
_validator_semaphore: trio.Semaphore
|
||||||
|
|
||||||
seen_messages: LastSeenCache
|
seen_messages: LastSeenCache
|
||||||
|
|
||||||
@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
msg_id_constructor: Callable[
|
msg_id_constructor: Callable[
|
||||||
[rpc_pb2.Message], bytes
|
[rpc_pb2.Message], bytes
|
||||||
] = get_peer_and_seqno_msg_id,
|
] = get_peer_and_seqno_msg_id,
|
||||||
|
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Construct a new Pubsub object, which is responsible for handling all
|
Construct a new Pubsub object, which is responsible for handling all
|
||||||
@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
# Therefore, we can only close from the receive side.
|
# Therefore, we can only close from the receive side.
|
||||||
self.peer_receive_channel = peer_receive
|
self.peer_receive_channel = peer_receive
|
||||||
self.dead_peer_receive_channel = dead_peer_receive
|
self.dead_peer_receive_channel = dead_peer_receive
|
||||||
|
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
|
||||||
# Register a notifee
|
# Register a notifee
|
||||||
self.host.get_network().register_notifee(
|
self.host.get_network().register_notifee(
|
||||||
PubsubNotifee(peer_send, dead_peer_send)
|
PubsubNotifee(peer_send, dead_peer_send)
|
||||||
@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
|
|||||||
|
|
||||||
logger.debug("successfully published message %s", msg)
|
logger.debug("successfully published message %s", msg)
|
||||||
|
|
||||||
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def validate_msg(
|
||||||
|
self,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Validate the received message.
|
Validate the received message.
|
||||||
|
|
||||||
@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
|
|||||||
if not validator(msg_forwarder, msg):
|
if not validator(msg_forwarder, msg):
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
# TODO: Implement throttle on async validators
|
|
||||||
|
|
||||||
if len(async_topic_validators) > 0:
|
if len(async_topic_validators) > 0:
|
||||||
# Appends to lists are thread safe in CPython
|
# Appends to lists are thread safe in CPython
|
||||||
results = []
|
results: list[bool] = []
|
||||||
|
|
||||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
|
||||||
result = await func(msg_forwarder, msg)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for async_validator in async_topic_validators:
|
for async_validator in async_topic_validators:
|
||||||
nursery.start_soon(run_async_validator, async_validator)
|
nursery.start_soon(
|
||||||
|
self._run_async_validator,
|
||||||
|
async_validator,
|
||||||
|
msg_forwarder,
|
||||||
|
msg,
|
||||||
|
results,
|
||||||
|
)
|
||||||
|
|
||||||
if not all(results):
|
if not all(results):
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
|
async def _run_async_validator(
|
||||||
|
self,
|
||||||
|
func: AsyncValidatorFn,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
results: list[bool],
|
||||||
|
) -> None:
|
||||||
|
async with self._validator_semaphore:
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
Push a pubsub message to others.
|
Push a pubsub message to others.
|
||||||
|
|||||||
2
newsfragments/755.performance.rst
Normal file
2
newsfragments/755.performance.rst
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
Added throttling for async topic validators in validate_msg, enforcing a
|
||||||
|
concurrency limit to prevent resource exhaustion under heavy load.
|
||||||
1
newsfragments/772.internal.rst
Normal file
1
newsfragments/772.internal.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Replace the libp2p.peer.ID cache attributes with functools.cached_property functional decorator.
|
||||||
1
newsfragments/775.docs.rst
Normal file
1
newsfragments/775.docs.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.
|
||||||
1
newsfragments/778.bugfix.rst
Normal file
1
newsfragments/778.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fixed incorrect handling of raw protobuf format in identify protocol. The identify example now properly handles both raw and length-prefixed (varint) message formats, provides better error messages, and displays connection status with peer IDs. Replaced mock-based tests with comprehensive real network integration tests for both formats.
|
||||||
@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
|
|||||||
- `performance`
|
- `performance`
|
||||||
- `removal`
|
- `removal`
|
||||||
|
|
||||||
So for example: `123.feature.rst`, `456.bugfix.rst`
|
So for example: `1024.feature.rst`
|
||||||
|
|
||||||
|
**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks.
|
||||||
|
|
||||||
|
```
|
||||||
|
Added support for Ed25519 key generation in libp2p peer identity creation.
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
If the PR fixes an issue, use that number here. If there is no issue,
|
If the PR fixes an issue, use that number here. If there is no issue,
|
||||||
then open up the PR first and use the PR number for the newsfragment.
|
then open up the PR first and use the PR number for the newsfragment.
|
||||||
|
|
||||||
Note that the `towncrier` tool will automatically
|
**Note** that the `towncrier` tool will automatically
|
||||||
reflow your text, so don't try to do any fancy formatting. Run
|
reflow your text, so don't try to do any fancy formatting. Run
|
||||||
`towncrier build --draft` to get a preview of what the release notes entry
|
`towncrier build --draft` to get a preview of what the release notes entry
|
||||||
will look like in the final release notes.
|
will look like in the final release notes.
|
||||||
|
|||||||
241
tests/core/identity/identify/test_identify_integration.py
Normal file
241
tests/core/identity/identify/test_identify_integration.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.identity.identify.identify import (
|
||||||
|
AGENT_VERSION,
|
||||||
|
ID,
|
||||||
|
PROTOCOL_VERSION,
|
||||||
|
_multiaddr_to_bytes,
|
||||||
|
identify_handler_for,
|
||||||
|
parse_identify_response,
|
||||||
|
)
|
||||||
|
from tests.utils.factories import host_pair_factory
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.identity.identify-integration-test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_protocol_varint_format_integration(security_protocol):
|
||||||
|
"""Test identify protocol with varint format in real network scenario."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make identify request
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
assert result.listen_addrs == [
|
||||||
|
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_protocol_raw_format_integration(security_protocol):
|
||||||
|
"""Test identify protocol with raw format in real network scenario."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make identify request
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
assert result.listen_addrs == [
|
||||||
|
_multiaddr_to_bytes(addr) for addr in host_a.get_addrs()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_default_format_behavior(security_protocol):
|
||||||
|
"""Test identify protocol uses correct default format."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Use default identify handler (should use varint format)
|
||||||
|
host_a.set_stream_handler(ID, identify_handler_for(host_a))
|
||||||
|
|
||||||
|
# Make identify request
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_cross_format_compatibility_varint_to_raw(security_protocol):
|
||||||
|
"""Test varint dialer with raw listener compatibility."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Host A uses raw format
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Host B makes request (will automatically detect format)
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response (should work with automatic format detection)
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_cross_format_compatibility_raw_to_varint(security_protocol):
|
||||||
|
"""Test raw dialer with varint listener compatibility."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Host A uses varint format
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Host B makes request (will automatically detect format)
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response (should work with automatic format detection)
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_format_detection_robustness(security_protocol):
|
||||||
|
"""Test identify protocol format detection is robust with various message sizes."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Test both formats with different message sizes
|
||||||
|
for use_varint in [True, False]:
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=use_varint)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make identify request
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_large_message_handling(security_protocol):
|
||||||
|
"""Test identify protocol handles large messages with many protocols."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Add many protocols to make the message larger
|
||||||
|
async def dummy_handler(stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
host_a.set_stream_handler(TProtocol(f"/test/protocol/{i}"), dummy_handler)
|
||||||
|
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make identify request
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read(8192)
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
result = parse_identify_response(response)
|
||||||
|
|
||||||
|
# Verify response content
|
||||||
|
assert result.agent_version == AGENT_VERSION
|
||||||
|
assert result.protocol_version == PROTOCOL_VERSION
|
||||||
|
assert result.public_key == host_a.get_public_key().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_identify_message_equivalence_real_network(security_protocol):
|
||||||
|
"""Test that both formats produce equivalent messages in real network."""
|
||||||
|
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||||
|
host_a,
|
||||||
|
host_b,
|
||||||
|
):
|
||||||
|
# Test varint format
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=True)
|
||||||
|
)
|
||||||
|
stream_varint = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response_varint = await stream_varint.read(8192)
|
||||||
|
await stream_varint.close()
|
||||||
|
|
||||||
|
# Test raw format
|
||||||
|
host_a.set_stream_handler(
|
||||||
|
ID, identify_handler_for(host_a, use_varint_format=False)
|
||||||
|
)
|
||||||
|
stream_raw = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response_raw = await stream_raw.read(8192)
|
||||||
|
await stream_raw.close()
|
||||||
|
|
||||||
|
# Parse both responses
|
||||||
|
result_varint = parse_identify_response(response_varint)
|
||||||
|
result_raw = parse_identify_response(response_raw)
|
||||||
|
|
||||||
|
# Both should produce identical parsed results
|
||||||
|
assert result_varint.agent_version == result_raw.agent_version
|
||||||
|
assert result_varint.protocol_version == result_raw.protocol_version
|
||||||
|
assert result_varint.public_key == result_raw.public_key
|
||||||
|
assert result_varint.listen_addrs == result_raw.listen_addrs
|
||||||
@ -1,410 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from libp2p.identity.identify.identify import (
|
|
||||||
_mk_identify_protobuf,
|
|
||||||
)
|
|
||||||
from libp2p.identity.identify.pb.identify_pb2 import (
|
|
||||||
Identify,
|
|
||||||
)
|
|
||||||
from libp2p.io.abc import Closer, Reader, Writer
|
|
||||||
from libp2p.utils.varint import (
|
|
||||||
decode_varint_from_bytes,
|
|
||||||
encode_varint_prefixed,
|
|
||||||
)
|
|
||||||
from tests.utils.factories import (
|
|
||||||
host_pair_factory,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockStream(Reader, Writer, Closer):
|
|
||||||
"""Mock stream for testing identify protocol compatibility."""
|
|
||||||
|
|
||||||
def __init__(self, data: bytes):
|
|
||||||
self.data = data
|
|
||||||
self.position = 0
|
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
async def read(self, n: int | None = None) -> bytes:
|
|
||||||
if self.closed or self.position >= len(self.data):
|
|
||||||
return b""
|
|
||||||
if n is None:
|
|
||||||
n = len(self.data) - self.position
|
|
||||||
result = self.data[self.position : self.position + n]
|
|
||||||
self.position += len(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
|
||||||
# Mock write - just store the data
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
|
|
||||||
def create_identify_message(host, observed_multiaddr=None):
|
|
||||||
"""Create an identify protobuf message."""
|
|
||||||
return _mk_identify_protobuf(host, observed_multiaddr)
|
|
||||||
|
|
||||||
|
|
||||||
def create_new_format_message(identify_msg):
|
|
||||||
"""Create a new format (length-prefixed) identify message."""
|
|
||||||
msg_bytes = identify_msg.SerializeToString()
|
|
||||||
return encode_varint_prefixed(msg_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def create_old_format_message(identify_msg):
|
|
||||||
"""Create an old format (raw protobuf) identify message."""
|
|
||||||
return identify_msg.SerializeToString()
|
|
||||||
|
|
||||||
|
|
||||||
async def read_new_format_message(stream) -> bytes:
|
|
||||||
"""Read a new format (length-prefixed) identify message."""
|
|
||||||
# Read varint length prefix
|
|
||||||
length_bytes = b""
|
|
||||||
while True:
|
|
||||||
b = await stream.read(1)
|
|
||||||
if not b:
|
|
||||||
break
|
|
||||||
length_bytes += b
|
|
||||||
if b[0] & 0x80 == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not length_bytes:
|
|
||||||
raise ValueError("No length prefix received")
|
|
||||||
|
|
||||||
msg_length = decode_varint_from_bytes(length_bytes)
|
|
||||||
|
|
||||||
# Read the protobuf message
|
|
||||||
response = await stream.read(msg_length)
|
|
||||||
if len(response) != msg_length:
|
|
||||||
raise ValueError("Incomplete message received")
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
async def read_old_format_message(stream) -> bytes:
|
|
||||||
"""Read an old format (raw protobuf) identify message."""
|
|
||||||
# Read all available data
|
|
||||||
response = b""
|
|
||||||
while True:
|
|
||||||
chunk = await stream.read(4096)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
response += chunk
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
async def read_compatible_message(stream) -> bytes:
|
|
||||||
"""Read an identify message in either old or new format."""
|
|
||||||
# Try to read a few bytes to detect the format
|
|
||||||
first_bytes = await stream.read(10)
|
|
||||||
if not first_bytes:
|
|
||||||
raise ValueError("No data received")
|
|
||||||
|
|
||||||
# Try to decode as varint length prefix (new format)
|
|
||||||
try:
|
|
||||||
msg_length = decode_varint_from_bytes(first_bytes)
|
|
||||||
|
|
||||||
# Validate that the length is reasonable (not too large)
|
|
||||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
|
||||||
# Calculate how many bytes the varint consumed
|
|
||||||
varint_len = 0
|
|
||||||
for i, byte in enumerate(first_bytes):
|
|
||||||
varint_len += 1
|
|
||||||
if (byte & 0x80) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Read the remaining protobuf message
|
|
||||||
remaining_bytes = await stream.read(
|
|
||||||
msg_length - (len(first_bytes) - varint_len)
|
|
||||||
)
|
|
||||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
|
||||||
message_data = first_bytes[varint_len:] + remaining_bytes
|
|
||||||
|
|
||||||
# Try to parse as protobuf to validate
|
|
||||||
try:
|
|
||||||
Identify().ParseFromString(message_data)
|
|
||||||
return message_data
|
|
||||||
except Exception:
|
|
||||||
# If protobuf parsing fails, fall back to old format
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fall back to old format (raw protobuf)
|
|
||||||
response = first_bytes
|
|
||||||
|
|
||||||
# Read more data if available
|
|
||||||
while True:
|
|
||||||
chunk = await stream.read(4096)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
response += chunk
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
async def read_compatible_message_simple(stream) -> bytes:
|
|
||||||
"""Read a message in either old or new format (simplified version for testing)."""
|
|
||||||
# Try to read a few bytes to detect the format
|
|
||||||
first_bytes = await stream.read(10)
|
|
||||||
if not first_bytes:
|
|
||||||
raise ValueError("No data received")
|
|
||||||
|
|
||||||
# Try to decode as varint length prefix (new format)
|
|
||||||
try:
|
|
||||||
msg_length = decode_varint_from_bytes(first_bytes)
|
|
||||||
|
|
||||||
# Validate that the length is reasonable (not too large)
|
|
||||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
|
||||||
# Calculate how many bytes the varint consumed
|
|
||||||
varint_len = 0
|
|
||||||
for i, byte in enumerate(first_bytes):
|
|
||||||
varint_len += 1
|
|
||||||
if (byte & 0x80) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Read the remaining message
|
|
||||||
remaining_bytes = await stream.read(
|
|
||||||
msg_length - (len(first_bytes) - varint_len)
|
|
||||||
)
|
|
||||||
if len(remaining_bytes) == msg_length - (len(first_bytes) - varint_len):
|
|
||||||
return first_bytes[varint_len:] + remaining_bytes
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fall back to old format (raw data)
|
|
||||||
response = first_bytes
|
|
||||||
|
|
||||||
# Read more data if available
|
|
||||||
while True:
|
|
||||||
chunk = await stream.read(4096)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
response += chunk
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def detect_format(data):
|
|
||||||
"""Detect if data is in new or old format (varint-prefixed or raw protobuf)."""
|
|
||||||
if not data:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
# Try to decode as varint
|
|
||||||
try:
|
|
||||||
msg_length = decode_varint_from_bytes(data)
|
|
||||||
|
|
||||||
# Validate that the length is reasonable
|
|
||||||
if msg_length > 0 and msg_length <= 1024 * 1024: # Max 1MB
|
|
||||||
# Calculate varint length
|
|
||||||
varint_len = 0
|
|
||||||
for i, byte in enumerate(data):
|
|
||||||
varint_len += 1
|
|
||||||
if (byte & 0x80) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check if we have enough data for the message
|
|
||||||
if len(data) >= varint_len + msg_length:
|
|
||||||
# Additional check: try to parse the message as protobuf
|
|
||||||
try:
|
|
||||||
message_data = data[varint_len : varint_len + msg_length]
|
|
||||||
Identify().ParseFromString(message_data)
|
|
||||||
return "new"
|
|
||||||
except Exception:
|
|
||||||
# If protobuf parsing fails, it's probably not a valid new format
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# If varint decoding fails or length is unreasonable, assume old format
|
|
||||||
return "old"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_new_format_compatibility(security_protocol):
|
|
||||||
"""Test that identify protocol works with new format (length-prefixed) messages."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Create new format message
|
|
||||||
new_format_data = create_new_format_message(identify_msg)
|
|
||||||
|
|
||||||
# Create mock stream with new format data
|
|
||||||
stream = MockStream(new_format_data)
|
|
||||||
|
|
||||||
# Read using new format reader
|
|
||||||
response = await read_new_format_message(stream)
|
|
||||||
|
|
||||||
# Parse the response
|
|
||||||
parsed_msg = Identify()
|
|
||||||
parsed_msg.ParseFromString(response)
|
|
||||||
|
|
||||||
# Verify the message content
|
|
||||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
|
||||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
|
||||||
assert parsed_msg.public_key == identify_msg.public_key
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_old_format_compatibility(security_protocol):
|
|
||||||
"""Test that identify protocol works with old format (raw protobuf) messages."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Create old format message
|
|
||||||
old_format_data = create_old_format_message(identify_msg)
|
|
||||||
|
|
||||||
# Create mock stream with old format data
|
|
||||||
stream = MockStream(old_format_data)
|
|
||||||
|
|
||||||
# Read using old format reader
|
|
||||||
response = await read_old_format_message(stream)
|
|
||||||
|
|
||||||
# Parse the response
|
|
||||||
parsed_msg = Identify()
|
|
||||||
parsed_msg.ParseFromString(response)
|
|
||||||
|
|
||||||
# Verify the message content
|
|
||||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
|
||||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
|
||||||
assert parsed_msg.public_key == identify_msg.public_key
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_backward_compatibility_old_format(security_protocol):
|
|
||||||
"""Test backward compatibility reader with old format messages."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Create old format message
|
|
||||||
old_format_data = create_old_format_message(identify_msg)
|
|
||||||
|
|
||||||
# Create mock stream with old format data
|
|
||||||
stream = MockStream(old_format_data)
|
|
||||||
|
|
||||||
# Read using old format reader (which should work reliably)
|
|
||||||
response = await read_old_format_message(stream)
|
|
||||||
|
|
||||||
# Parse the response
|
|
||||||
parsed_msg = Identify()
|
|
||||||
parsed_msg.ParseFromString(response)
|
|
||||||
|
|
||||||
# Verify the message content
|
|
||||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
|
||||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
|
||||||
assert parsed_msg.public_key == identify_msg.public_key
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_backward_compatibility_new_format(security_protocol):
|
|
||||||
"""Test backward compatibility reader with new format messages."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Create new format message
|
|
||||||
new_format_data = create_new_format_message(identify_msg)
|
|
||||||
|
|
||||||
# Create mock stream with new format data
|
|
||||||
stream = MockStream(new_format_data)
|
|
||||||
|
|
||||||
# Read using new format reader (which should work reliably)
|
|
||||||
response = await read_new_format_message(stream)
|
|
||||||
|
|
||||||
# Parse the response
|
|
||||||
parsed_msg = Identify()
|
|
||||||
parsed_msg.ParseFromString(response)
|
|
||||||
|
|
||||||
# Verify the message content
|
|
||||||
assert parsed_msg.protocol_version == identify_msg.protocol_version
|
|
||||||
assert parsed_msg.agent_version == identify_msg.agent_version
|
|
||||||
assert parsed_msg.public_key == identify_msg.public_key
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_format_detection(security_protocol):
|
|
||||||
"""Test that the format detection works correctly."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Test new format detection
|
|
||||||
new_format_data = create_new_format_message(identify_msg)
|
|
||||||
format_type = detect_format(new_format_data)
|
|
||||||
assert format_type == "new", "New format should be detected correctly"
|
|
||||||
|
|
||||||
# Test old format detection
|
|
||||||
old_format_data = create_old_format_message(identify_msg)
|
|
||||||
format_type = detect_format(old_format_data)
|
|
||||||
assert format_type == "old", "Old format should be detected correctly"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_error_handling(security_protocol):
|
|
||||||
"""Test error handling for malformed messages."""
|
|
||||||
from libp2p.exceptions import ParseError
|
|
||||||
|
|
||||||
# Test with empty data
|
|
||||||
stream = MockStream(b"")
|
|
||||||
with pytest.raises(ValueError, match="No data received"):
|
|
||||||
await read_compatible_message(stream)
|
|
||||||
|
|
||||||
# Test with incomplete varint
|
|
||||||
stream = MockStream(b"\x80") # Incomplete varint
|
|
||||||
with pytest.raises(ParseError, match="Unexpected end of data"):
|
|
||||||
await read_new_format_message(stream)
|
|
||||||
|
|
||||||
# Test with invalid protobuf data
|
|
||||||
stream = MockStream(b"\x05invalid") # Length prefix but invalid protobuf
|
|
||||||
with pytest.raises(Exception): # Should fail when parsing protobuf
|
|
||||||
response = await read_new_format_message(stream)
|
|
||||||
Identify().ParseFromString(response)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_identify_message_equivalence(security_protocol):
|
|
||||||
"""Test that old and new format messages are equivalent."""
|
|
||||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
|
||||||
host_a,
|
|
||||||
host_b,
|
|
||||||
):
|
|
||||||
# Create identify message
|
|
||||||
identify_msg = create_identify_message(host_a)
|
|
||||||
|
|
||||||
# Create both formats
|
|
||||||
new_format_data = create_new_format_message(identify_msg)
|
|
||||||
old_format_data = create_old_format_message(identify_msg)
|
|
||||||
|
|
||||||
# Extract the protobuf message from new format
|
|
||||||
varint_len = 0
|
|
||||||
for i, byte in enumerate(new_format_data):
|
|
||||||
varint_len += 1
|
|
||||||
if (byte & 0x80) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
new_format_protobuf = new_format_data[varint_len:]
|
|
||||||
|
|
||||||
# The protobuf messages should be identical
|
|
||||||
assert new_format_protobuf == old_format_data, (
|
|
||||||
"Protobuf messages should be identical in both formats"
|
|
||||||
)
|
|
||||||
@ -5,10 +5,12 @@ import inspect
|
|||||||
from typing import (
|
from typing import (
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
)
|
)
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from libp2p.custom_types import AsyncValidatorFn
|
||||||
from libp2p.exceptions import (
|
from libp2p.exceptions import (
|
||||||
ValidationError,
|
ValidationError,
|
||||||
)
|
)
|
||||||
@ -243,7 +245,37 @@ async def test_get_msg_validators():
|
|||||||
((False, True), (True, False), (True, True)),
|
((False, True), (True, False), (True, True)),
|
||||||
)
|
)
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
async def test_validate_msg_with_throttle_condition(
|
||||||
|
is_topic_1_val_passed, is_topic_2_val_passed
|
||||||
|
):
|
||||||
|
CONCURRENCY_LIMIT = 10
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"concurrency_counter": 0,
|
||||||
|
"max_observed": 0,
|
||||||
|
}
|
||||||
|
lock = trio.Lock()
|
||||||
|
|
||||||
|
async def mock_run_async_validator(
|
||||||
|
self,
|
||||||
|
func: AsyncValidatorFn,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
results: list[bool],
|
||||||
|
) -> None:
|
||||||
|
async with self._validator_semaphore:
|
||||||
|
async with lock:
|
||||||
|
state["concurrency_counter"] += 1
|
||||||
|
if state["concurrency_counter"] > state["max_observed"]:
|
||||||
|
state["max_observed"] = state["concurrency_counter"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
results.append(result)
|
||||||
|
finally:
|
||||||
|
async with lock:
|
||||||
|
state["concurrency_counter"] -= 1
|
||||||
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
|
|
||||||
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
|
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
|
||||||
@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
|||||||
seqno=b"\x00" * 8,
|
seqno=b"\x00" * 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
with patch(
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
|
||||||
else:
|
new=mock_run_async_validator,
|
||||||
with pytest.raises(ValidationError):
|
):
|
||||||
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
|
||||||
|
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
|
||||||
|
f"Max concurrency observed: {state['max_observed']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
|
|||||||
Reference in New Issue
Block a user