Implement validation throttler for message validation in Pubsub

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
This commit is contained in:
2025-06-27 20:29:54 +05:30
parent 6a92fa26eb
commit 3431efe97f
2 changed files with 411 additions and 44 deletions

View File

@ -11,10 +11,6 @@ import functools
import hashlib
import logging
import time
from typing import (
NamedTuple,
cast,
)
import base58
import trio
@ -30,8 +26,6 @@ from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import (
AsyncValidatorFn,
SyncValidatorFn,
TProtocol,
ValidatorFn,
)
@ -76,6 +70,11 @@ from .pubsub_notifee import (
from .subscription import (
TrioSubscriptionAPI,
)
from .validation_throttler import (
TopicValidator,
ValidationResult,
ValidationThrottler,
)
from .validators import (
PUBSUB_SIGNING_PREFIX,
signature_validator,
@ -96,11 +95,6 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
return base64.b64encode(hashlib.sha256(msg.data).digest())
class TopicValidator(NamedTuple):
validator: ValidatorFn
is_async: bool
class Pubsub(Service, IPubsub):
host: IHost
@ -142,6 +136,11 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
# TODO: these values have been copied from Go, but try to tune these dynamically
validation_queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
validation_worker_count: int | None = None,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -202,7 +201,15 @@ class Pubsub(Service, IPubsub):
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
self.peers = {}
# Map of topic to topic validator
# Validation Throttler
self.validation_throttler = ValidationThrottler(
queue_size=validation_queue_size,
global_throttle_limit=global_throttle_limit,
default_topic_throttle_limit=default_topic_throttle_limit,
worker_count=validation_worker_count or 4,
)
# Keep a mapping of topic -> TopicValidator for easier lookup
self.topic_validators = {}
self.counter = int(time.time())
@ -214,10 +221,19 @@ class Pubsub(Service, IPubsub):
self.event_handle_dead_peer_queue_started = trio.Event()
async def run(self) -> None:
self.manager.run_daemon_task(self._start_validation_throttler)
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
await self.manager.wait_finished()
async def _start_validation_throttler(self) -> None:
"""Start validation throttler in current nursery context"""
async with trio.open_nursery() as nursery:
await self.validation_throttler.start(nursery)
# Keep nursery alive until service stops
while self.manager.is_running:
await trio.sleep(1)
@property
def my_id(self) -> ID:
return self.host.get_id()
@ -297,7 +313,12 @@ class Pubsub(Service, IPubsub):
)
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
self,
topic: str,
validator: ValidatorFn,
is_async_validator: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
) -> None:
"""
Register a validator under the given topic. One topic can only have one
@ -306,8 +327,18 @@ class Pubsub(Service, IPubsub):
:param topic: the topic to register validator under
:param validator: the validator used to validate messages published to the topic
:param is_async_validator: indicate if the validator is an asynchronous validator
:param timeout: optional timeout for the validator
:param throttle_limit: optional throttle limit for the validator
""" # noqa: E501
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
# Create throttled topic validator
topic_validator = self.validation_throttler.create_topic_validator(
topic=topic,
validator=validator,
is_async=is_async_validator,
timeout=timeout,
throttle_limit=throttle_limit,
)
self.topic_validators[topic] = topic_validator
def remove_topic_validator(self, topic: str) -> None:
"""
@ -317,17 +348,18 @@ class Pubsub(Service, IPubsub):
"""
self.topic_validators.pop(topic, None)
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
"""
Get all validators corresponding to the topics in the message.
:param msg: the message published to the topic
:return: list of topic validators for the message's topics
"""
return tuple(
return [
self.topic_validators[topic]
for topic in msg.topicIDs
if topic in self.topic_validators
)
]
def add_to_blacklist(self, peer_id: ID) -> None:
"""
@ -663,39 +695,56 @@ class Pubsub(Service, IPubsub):
:param msg_forwarder: the peer who forward us the message.
:param msg: the message.
"""
sync_topic_validators: list[SyncValidatorFn] = []
async_topic_validators: list[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validators.append(
cast(AsyncValidatorFn, topic_validator.validator)
)
else:
sync_topic_validators.append(
cast(SyncValidatorFn, topic_validator.validator)
)
# Get applicable validators for this message
validators = self.get_msg_validators(msg)
for validator in sync_topic_validators:
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
if not validators:
# No validators, accept immediately
return
# TODO: Implement throttle on async validators
# Use trio.Event for async coordination
validation_event = trio.Event()
result_container: dict[str, ValidationResult | None | Exception] = {
"result": None,
"error": None,
}
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result: bool = True
def handle_validation_result(
result: ValidationResult, error: Exception | None
) -> None:
result_container["result"] = result
result_container["error"] = error
validation_event.set()
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
result = await func(msg_forwarder, msg)
final_result = final_result and result
# Submit for throttled validation
success = await self.validation_throttler.submit_validation(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=handle_validation_result,
)
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not success:
# Validation was throttled at queue level
raise ValidationError("Validation throttled at queue level")
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}")
# Wait for validation result
await validation_event.wait()
result = result_container["result"]
error = result_container["error"]
if error:
raise ValidationError(f"Validation error: {error}")
if result == ValidationResult.REJECT:
raise ValidationError("Message validation rejected")
elif result == ValidationResult.THROTTLED:
raise ValidationError("Message validation throttled")
elif result == ValidationResult.IGNORE:
# Treat IGNORE as rejection for now, or you could silently drop
raise ValidationError("Message validation ignored")
# ACCEPT case - just return normally
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""

View File

@ -0,0 +1,318 @@
from collections.abc import (
Awaitable,
Callable,
)
from dataclasses import dataclass
from enum import Enum
import inspect
import logging
from typing import (
Any,
NamedTuple,
)
import trio
from libp2p.custom_types import (
ValidatorFn,
)
logger = logging.getLogger("libp2p.pubsub.validation")
class ValidationResult(Enum):
ACCEPT = "accept"
REJECT = "reject"
IGNORE = "ignore"
THROTTLED = "throttled"
@dataclass
class ValidationRequest:
"""Request for message validation"""
validators: list["TopicValidator"]
# TODO: Use a more specific type for msg_forwarder
msg_forwarder: Any # peer ID
msg: Any # message object
result_callback: Callable[[ValidationResult, Exception | None], None]
class TopicValidator(NamedTuple):
topic: str
validator: ValidatorFn
is_async: bool
timeout: float | None = None
# Per-topic throttle semaphore
throttle_semaphore: trio.Semaphore | None = None
class ValidationThrottler:
"""Manages all validation throttling mechanisms"""
def __init__(
self,
queue_size: int = 32,
global_throttle_limit: int = 8192,
default_topic_throttle_limit: int = 1024,
worker_count: int | None = None,
):
# 1. Queue-level throttling - bounded memory channel
self._validation_send, self._validation_receive = trio.open_memory_channel[
ValidationRequest
](queue_size)
# 2. Global validation throttling - limits total concurrent async validations
self._global_throttle = trio.Semaphore(global_throttle_limit)
# 3. Per-topic throttling - each validator gets its own semaphore
self._default_topic_throttle_limit = default_topic_throttle_limit
# Worker management
# TODO: Find a better way to manage worker count
self._worker_count = worker_count or 4
self._running = False
async def start(self, nursery: trio.Nursery) -> None:
"""Start the validation workers"""
self._running = True
# Start validation worker tasks
for i in range(self._worker_count):
nursery.start_soon(self._validation_worker, f"worker-{i}")
async def stop(self) -> None:
"""Stop the validation system"""
self._running = False
await self._validation_send.aclose()
def create_topic_validator(
self,
topic: str,
validator: ValidatorFn,
is_async: bool,
timeout: float | None = None,
throttle_limit: int | None = None,
) -> TopicValidator:
"""Create a new topic validator with its own throttle"""
limit = throttle_limit or self._default_topic_throttle_limit
throttle_sem = trio.Semaphore(limit)
return TopicValidator(
topic=topic,
validator=validator,
is_async=is_async,
timeout=timeout,
throttle_semaphore=throttle_sem,
)
async def submit_validation(
self,
validators: list[TopicValidator],
msg_forwarder: Any,
msg: Any,
result_callback: Callable[[ValidationResult, Exception | None], None],
) -> bool:
"""
Submit a message for validation.
Returns True if queued successfully, False if queue is full (throttled).
"""
if not self._running:
result_callback(
ValidationResult.REJECT, Exception("Validation system not running")
)
return False
request = ValidationRequest(
validators=validators,
msg_forwarder=msg_forwarder,
msg=msg,
result_callback=result_callback,
)
try:
# This will raise trio.WouldBlock if queue is full
self._validation_send.send_nowait(request)
return True
except trio.WouldBlock:
# Queue-level throttling: drop the message
logger.debug(
"Validation queue full, dropping message from %s", msg_forwarder
)
result_callback(
ValidationResult.THROTTLED, Exception("Validation queue full")
)
return False
async def _validation_worker(self, worker_id: str) -> None:
"""Worker that processes validation requests"""
logger.debug("Validation worker %s started", worker_id)
async with self._validation_receive:
async for request in self._validation_receive:
if not self._running:
break
try:
# Process the validation request
result = await self._validate_message(request)
request.result_callback(result, None)
except Exception as e:
logger.exception("Error in validation worker %s", worker_id)
request.result_callback(ValidationResult.REJECT, e)
logger.debug("Validation worker %s stopped", worker_id)
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
"""Core validation logic with throttling"""
validators = request.validators
msg_forwarder = request.msg_forwarder
msg = request.msg
if not validators:
return ValidationResult.ACCEPT
# Separate sync and async validators
sync_validators = [v for v in validators if not v.is_async]
async_validators = [v for v in validators if v.is_async]
# Run synchronous validators first
for validator in sync_validators:
try:
# Apply per-topic throttling even for sync validators
if validator.throttle_semaphore:
validator.throttle_semaphore.acquire_nowait()
try:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
finally:
validator.throttle_semaphore.release()
else:
result = validator.validator(msg_forwarder, msg)
if not result:
return ValidationResult.REJECT
except trio.WouldBlock:
# Per-topic throttling for sync validator
logger.debug("Sync validation throttled for topic %s", validator.topic)
return ValidationResult.THROTTLED
except Exception as e:
logger.exception(
"Sync validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
# Handle async validators with global + per-topic throttling
if async_validators:
return await self._validate_async_validators(
async_validators, msg_forwarder, msg
)
return ValidationResult.ACCEPT
async def _validate_async_validators(
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
) -> ValidationResult:
"""Handle async validators with proper throttling"""
if len(validators) == 1:
# Fast path for single validator
return await self._validate_single_async_validator(
validators[0], msg_forwarder, msg
)
# Multiple async validators - run them concurrently
try:
# Try to acquire global throttle slot
self._global_throttle.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Global validation throttle exceeded, dropping message from %s",
msg_forwarder,
)
return ValidationResult.THROTTLED
try:
async with trio.open_nursery() as nursery:
results = {}
async def run_validator(validator: TopicValidator, index: int) -> None:
"""Run a single async validator and store the result"""
nonlocal results
result = await self._validate_single_async_validator(
validator, msg_forwarder, msg
)
results[index] = result
# Start all validators concurrently
for i, validator in enumerate(validators):
nursery.start_soon(run_validator, validator, i)
# Process results - any reject or throttle causes overall failure
final_result = ValidationResult.ACCEPT
for result in results.values():
if result == ValidationResult.REJECT:
return ValidationResult.REJECT
elif result == ValidationResult.THROTTLED:
final_result = ValidationResult.THROTTLED
elif (
result == ValidationResult.IGNORE
and final_result == ValidationResult.ACCEPT
):
final_result = ValidationResult.IGNORE
return final_result
finally:
self._global_throttle.release()
return ValidationResult.IGNORE
async def _validate_single_async_validator(
self, validator: TopicValidator, msg_forwarder: Any, msg: Any
) -> ValidationResult:
"""Validate with a single async validator"""
# Apply per-topic throttling
if validator.throttle_semaphore:
try:
validator.throttle_semaphore.acquire_nowait()
except trio.WouldBlock:
logger.debug(
"Per-topic validation throttled for topic %s", validator.topic
)
return ValidationResult.THROTTLED
else:
# Fallback if no throttle semaphore configured
pass
try:
# Apply timeout if configured
result: bool | Awaitable[bool]
if validator.timeout:
with trio.fail_after(validator.timeout):
func = validator.validator
if inspect.iscoroutinefunction(func):
result = await func(msg_forwarder, msg)
else:
result = func(msg_forwarder, msg)
else:
func = validator.validator
if inspect.iscoroutinefunction(func):
result = await func(msg_forwarder, msg)
else:
result = func(msg_forwarder, msg)
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
except trio.TooSlowError:
logger.debug("Validation timeout for topic %s", validator.topic)
return ValidationResult.IGNORE
except Exception as e:
logger.exception(
"Async validator failed for topic %s: %s", validator.topic, e
)
return ValidationResult.REJECT
finally:
if validator.throttle_semaphore:
validator.throttle_semaphore.release()
return ValidationResult.IGNORE