mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Merge branch 'main' into fix/issue-778-Incorrect_handling_of_raw_format_in_identify
This commit is contained in:
@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
|
|||||||
is_async: bool
|
is_async: bool
|
||||||
|
|
||||||
|
|
||||||
|
MAX_CONCURRENT_VALIDATORS = 10
|
||||||
|
|
||||||
|
|
||||||
class Pubsub(Service, IPubsub):
|
class Pubsub(Service, IPubsub):
|
||||||
host: IHost
|
host: IHost
|
||||||
|
|
||||||
@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
|
|
||||||
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||||
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
|
||||||
|
_validator_semaphore: trio.Semaphore
|
||||||
|
|
||||||
seen_messages: LastSeenCache
|
seen_messages: LastSeenCache
|
||||||
|
|
||||||
@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
msg_id_constructor: Callable[
|
msg_id_constructor: Callable[
|
||||||
[rpc_pb2.Message], bytes
|
[rpc_pb2.Message], bytes
|
||||||
] = get_peer_and_seqno_msg_id,
|
] = get_peer_and_seqno_msg_id,
|
||||||
|
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Construct a new Pubsub object, which is responsible for handling all
|
Construct a new Pubsub object, which is responsible for handling all
|
||||||
@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
# Therefore, we can only close from the receive side.
|
# Therefore, we can only close from the receive side.
|
||||||
self.peer_receive_channel = peer_receive
|
self.peer_receive_channel = peer_receive
|
||||||
self.dead_peer_receive_channel = dead_peer_receive
|
self.dead_peer_receive_channel = dead_peer_receive
|
||||||
|
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
|
||||||
# Register a notifee
|
# Register a notifee
|
||||||
self.host.get_network().register_notifee(
|
self.host.get_network().register_notifee(
|
||||||
PubsubNotifee(peer_send, dead_peer_send)
|
PubsubNotifee(peer_send, dead_peer_send)
|
||||||
@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
|
|||||||
|
|
||||||
logger.debug("successfully published message %s", msg)
|
logger.debug("successfully published message %s", msg)
|
||||||
|
|
||||||
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def validate_msg(
|
||||||
|
self,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Validate the received message.
|
Validate the received message.
|
||||||
|
|
||||||
@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
|
|||||||
if not validator(msg_forwarder, msg):
|
if not validator(msg_forwarder, msg):
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
# TODO: Implement throttle on async validators
|
|
||||||
|
|
||||||
if len(async_topic_validators) > 0:
|
if len(async_topic_validators) > 0:
|
||||||
# Appends to lists are thread safe in CPython
|
# Appends to lists are thread safe in CPython
|
||||||
results = []
|
results: list[bool] = []
|
||||||
|
|
||||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
|
||||||
result = await func(msg_forwarder, msg)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for async_validator in async_topic_validators:
|
for async_validator in async_topic_validators:
|
||||||
nursery.start_soon(run_async_validator, async_validator)
|
nursery.start_soon(
|
||||||
|
self._run_async_validator,
|
||||||
|
async_validator,
|
||||||
|
msg_forwarder,
|
||||||
|
msg,
|
||||||
|
results,
|
||||||
|
)
|
||||||
|
|
||||||
if not all(results):
|
if not all(results):
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
|
async def _run_async_validator(
|
||||||
|
self,
|
||||||
|
func: AsyncValidatorFn,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
results: list[bool],
|
||||||
|
) -> None:
|
||||||
|
async with self._validator_semaphore:
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
Push a pubsub message to others.
|
Push a pubsub message to others.
|
||||||
|
|||||||
2
newsfragments/755.performance.rst
Normal file
2
newsfragments/755.performance.rst
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
Added throttling for async topic validators in validate_msg, enforcing a
|
||||||
|
concurrency limit to prevent resource exhaustion under heavy load.
|
||||||
1
newsfragments/775.docs.rst
Normal file
1
newsfragments/775.docs.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Clarified the requirement for a trailing newline in newsfragments to pass lint checks.
|
||||||
@ -18,12 +18,19 @@ Each file should be named like `<ISSUE>.<TYPE>.rst`, where
|
|||||||
- `performance`
|
- `performance`
|
||||||
- `removal`
|
- `removal`
|
||||||
|
|
||||||
So for example: `123.feature.rst`, `456.bugfix.rst`
|
So for example: `1024.feature.rst`
|
||||||
|
|
||||||
|
**Important**: Ensure the file ends with a newline character (`\n`) to pass GitHub tox linting checks.
|
||||||
|
|
||||||
|
```
|
||||||
|
Added support for Ed25519 key generation in libp2p peer identity creation.
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
If the PR fixes an issue, use that number here. If there is no issue,
|
If the PR fixes an issue, use that number here. If there is no issue,
|
||||||
then open up the PR first and use the PR number for the newsfragment.
|
then open up the PR first and use the PR number for the newsfragment.
|
||||||
|
|
||||||
Note that the `towncrier` tool will automatically
|
**Note** that the `towncrier` tool will automatically
|
||||||
reflow your text, so don't try to do any fancy formatting. Run
|
reflow your text, so don't try to do any fancy formatting. Run
|
||||||
`towncrier build --draft` to get a preview of what the release notes entry
|
`towncrier build --draft` to get a preview of what the release notes entry
|
||||||
will look like in the final release notes.
|
will look like in the final release notes.
|
||||||
|
|||||||
@ -5,10 +5,12 @@ import inspect
|
|||||||
from typing import (
|
from typing import (
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
)
|
)
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from libp2p.custom_types import AsyncValidatorFn
|
||||||
from libp2p.exceptions import (
|
from libp2p.exceptions import (
|
||||||
ValidationError,
|
ValidationError,
|
||||||
)
|
)
|
||||||
@ -243,7 +245,37 @@ async def test_get_msg_validators():
|
|||||||
((False, True), (True, False), (True, True)),
|
((False, True), (True, False), (True, True)),
|
||||||
)
|
)
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
async def test_validate_msg_with_throttle_condition(
|
||||||
|
is_topic_1_val_passed, is_topic_2_val_passed
|
||||||
|
):
|
||||||
|
CONCURRENCY_LIMIT = 10
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"concurrency_counter": 0,
|
||||||
|
"max_observed": 0,
|
||||||
|
}
|
||||||
|
lock = trio.Lock()
|
||||||
|
|
||||||
|
async def mock_run_async_validator(
|
||||||
|
self,
|
||||||
|
func: AsyncValidatorFn,
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
results: list[bool],
|
||||||
|
) -> None:
|
||||||
|
async with self._validator_semaphore:
|
||||||
|
async with lock:
|
||||||
|
state["concurrency_counter"] += 1
|
||||||
|
if state["concurrency_counter"] > state["max_observed"]:
|
||||||
|
state["max_observed"] = state["concurrency_counter"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
results.append(result)
|
||||||
|
finally:
|
||||||
|
async with lock:
|
||||||
|
state["concurrency_counter"] -= 1
|
||||||
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
|
|
||||||
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
|
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
|
||||||
@ -280,11 +312,19 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
|||||||
seqno=b"\x00" * 8,
|
seqno=b"\x00" * 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
with patch(
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
|
||||||
else:
|
new=mock_run_async_validator,
|
||||||
with pytest.raises(ValidationError):
|
):
|
||||||
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
|
||||||
|
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
|
||||||
|
f"Max concurrency observed: {state['max_observed']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
|
|||||||
Reference in New Issue
Block a user