diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index 8ba7d471..481c8981 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -11,10 +11,6 @@ import functools
 import hashlib
 import logging
 import time
-from typing import (
-    NamedTuple,
-    cast,
-)
 
 import base58
 import trio
@@ -30,8 +26,6 @@ from libp2p.crypto.keys import (
     PrivateKey,
 )
 from libp2p.custom_types import (
-    AsyncValidatorFn,
-    SyncValidatorFn,
     TProtocol,
     ValidatorFn,
 )
@@ -76,6 +70,11 @@ from .pubsub_notifee import (
 from .subscription import (
     TrioSubscriptionAPI,
 )
+from .validation_throttler import (
+    TopicValidator,
+    ValidationResult,
+    ValidationThrottler,
+)
 from .validators import (
     PUBSUB_SIGNING_PREFIX,
     signature_validator,
@@ -96,11 +95,6 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
     return base64.b64encode(hashlib.sha256(msg.data).digest())
 
 
-class TopicValidator(NamedTuple):
-    validator: ValidatorFn
-    is_async: bool
-
-
 class Pubsub(Service, IPubsub):
     host: IHost
 
@@ -142,6 +136,11 @@ class Pubsub(Service, IPubsub):
         msg_id_constructor: Callable[
             [rpc_pb2.Message], bytes
         ] = get_peer_and_seqno_msg_id,
+        # TODO: these values have been copied from Go, but try to tune these dynamically
+        validation_queue_size: int = 32,
+        global_throttle_limit: int = 8192,
+        default_topic_throttle_limit: int = 1024,
+        validation_worker_count: int | None = None,
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -202,7 +201,15 @@ class Pubsub(Service, IPubsub):
         # Create peers map, which maps peer_id (as string) to stream (to a given peer)
         self.peers = {}
 
-        # Map of topic to topic validator
+        # Validation Throttler
+        self.validation_throttler = ValidationThrottler(
+            queue_size=validation_queue_size,
+            global_throttle_limit=global_throttle_limit,
+            default_topic_throttle_limit=default_topic_throttle_limit,
+            worker_count=validation_worker_count or 4,
+        )
+
+        # Keep a mapping of topic -> TopicValidator for easier lookup
         self.topic_validators = {}
 
         self.counter = int(time.time())
@@ -214,10 +221,19 @@ class Pubsub(Service, IPubsub):
         self.event_handle_dead_peer_queue_started = trio.Event()
 
     async def run(self) -> None:
+        self.manager.run_daemon_task(self._start_validation_throttler)
         self.manager.run_daemon_task(self.handle_peer_queue)
         self.manager.run_daemon_task(self.handle_dead_peer_queue)
         await self.manager.wait_finished()
 
+    async def _start_validation_throttler(self) -> None:
+        """Start validation throttler in current nursery context"""
+        async with trio.open_nursery() as nursery:
+            await self.validation_throttler.start(nursery)
+            # Keep nursery alive until service stops
+            while self.manager.is_running:
+                await trio.sleep(1)
+
     @property
     def my_id(self) -> ID:
         return self.host.get_id()
@@ -297,7 +313,12 @@ class Pubsub(Service, IPubsub):
         )
 
     def set_topic_validator(
-        self, topic: str, validator: ValidatorFn, is_async_validator: bool
+        self,
+        topic: str,
+        validator: ValidatorFn,
+        is_async_validator: bool,
+        timeout: float | None = None,
+        throttle_limit: int | None = None,
     ) -> None:
         """
         Register a validator under the given topic. One topic can only have one
@@ -306,8 +327,18 @@ class Pubsub(Service, IPubsub):
         :param topic: the topic to register validator under
         :param validator: the validator used to validate messages published to the topic
         :param is_async_validator: indicate if the validator is an asynchronous validator
+        :param timeout: optional timeout for the validator
+        :param throttle_limit: optional throttle limit for the validator
         """  # noqa: E501
-        self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
+        # Create throttled topic validator
+        topic_validator = self.validation_throttler.create_topic_validator(
+            topic=topic,
+            validator=validator,
+            is_async=is_async_validator,
+            timeout=timeout,
+            throttle_limit=throttle_limit,
+        )
+        self.topic_validators[topic] = topic_validator
 
     def remove_topic_validator(self, topic: str) -> None:
         """
@@ -317,17 +348,18 @@ class Pubsub(Service, IPubsub):
         """
         self.topic_validators.pop(topic, None)
 
-    def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
+    def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
         """
         Get all validators corresponding to the topics in the message.
 
         :param msg: the message published to the topic
+        :return: list of topic validators for the message's topics
         """
-        return tuple(
+        return [
             self.topic_validators[topic]
             for topic in msg.topicIDs
             if topic in self.topic_validators
-        )
+        ]
 
     def add_to_blacklist(self, peer_id: ID) -> None:
         """
@@ -663,39 +695,56 @@ class Pubsub(Service, IPubsub):
         :param msg_forwarder: the peer who forward us the message.
         :param msg: the message.
         """
-        sync_topic_validators: list[SyncValidatorFn] = []
-        async_topic_validators: list[AsyncValidatorFn] = []
-        for topic_validator in self.get_msg_validators(msg):
-            if topic_validator.is_async:
-                async_topic_validators.append(
-                    cast(AsyncValidatorFn, topic_validator.validator)
-                )
-            else:
-                sync_topic_validators.append(
-                    cast(SyncValidatorFn, topic_validator.validator)
-                )
+        # Get applicable validators for this message
+        validators = self.get_msg_validators(msg)
 
-        for validator in sync_topic_validators:
-            if not validator(msg_forwarder, msg):
-                raise ValidationError(f"Validation failed for msg={msg}")
+        if not validators:
+            # No validators, accept immediately
+            return
 
-        # TODO: Implement throttle on async validators
+        # Use trio.Event for async coordination
+        validation_event = trio.Event()
+        result_container: dict[str, ValidationResult | None | Exception] = {
+            "result": None,
+            "error": None,
+        }
 
-        if len(async_topic_validators) > 0:
-            # TODO: Use a better pattern
-            final_result: bool = True
+        def handle_validation_result(
+            result: ValidationResult, error: Exception | None
+        ) -> None:
+            result_container["result"] = result
+            result_container["error"] = error
+            validation_event.set()
 
-            async def run_async_validator(func: AsyncValidatorFn) -> None:
-                nonlocal final_result
-                result = await func(msg_forwarder, msg)
-                final_result = final_result and result
+        # Submit for throttled validation
+        success = await self.validation_throttler.submit_validation(
+            validators=validators,
+            msg_forwarder=msg_forwarder,
+            msg=msg,
+            result_callback=handle_validation_result,
+        )
 
-            async with trio.open_nursery() as nursery:
-                for async_validator in async_topic_validators:
-                    nursery.start_soon(run_async_validator, async_validator)
+        if not success:
+            # Validation was throttled at queue level
+            raise ValidationError("Validation throttled at queue level")
 
-            if not final_result:
-                raise ValidationError(f"Validation failed for msg={msg}")
+        # Wait for validation result
+        await validation_event.wait()
+
+        result = result_container["result"]
+        error = result_container["error"]
+
+        if error:
+            raise ValidationError(f"Validation error: {error}")
+
+        if result == ValidationResult.REJECT:
+            raise ValidationError("Message validation rejected")
+        elif result == ValidationResult.THROTTLED:
+            raise ValidationError("Message validation throttled")
+        elif result == ValidationResult.IGNORE:
+            # Treat IGNORE as rejection for now, or you could silently drop
+            raise ValidationError("Message validation ignored")
+        # ACCEPT case - just return normally
 
     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
         """
diff --git a/libp2p/pubsub/validation_throttler.py b/libp2p/pubsub/validation_throttler.py
new file mode 100644
index 00000000..38b4991c
--- /dev/null
+++ b/libp2p/pubsub/validation_throttler.py
@@ -0,0 +1,368 @@
+"""
+Throttled validation of pubsub messages.
+
+Requests are queued on a bounded channel and processed by worker tasks, with
+a global limit and a per-topic limit on concurrent validations.
+"""
+
+from collections.abc import (
+    Awaitable,
+    Callable,
+)
+from dataclasses import dataclass
+from enum import Enum
+import inspect
+import logging
+from typing import (
+    Any,
+    NamedTuple,
+)
+
+import trio
+
+from libp2p.custom_types import (
+    ValidatorFn,
+)
+
+logger = logging.getLogger("libp2p.pubsub.validation")
+
+
+class ValidationResult(Enum):
+    """Outcome reported for a validation request."""
+
+    ACCEPT = "accept"
+    REJECT = "reject"
+    IGNORE = "ignore"
+    THROTTLED = "throttled"
+
+
+@dataclass
+class ValidationRequest:
+    """Request for message validation"""
+
+    validators: list["TopicValidator"]
+    # TODO: Use a more specific type for msg_forwarder
+    msg_forwarder: Any  # peer ID
+    msg: Any  # message object
+    # Invoked with the outcome, plus the exception that caused it (if any)
+    result_callback: Callable[[ValidationResult, Exception | None], None]
+
+
+class TopicValidator(NamedTuple):
+    """A validator function bound to a topic, with its throttling settings."""
+
+    topic: str
+    validator: ValidatorFn
+    is_async: bool
+    timeout: float | None = None
+    # Per-topic throttle semaphore
+    throttle_semaphore: trio.Semaphore | None = None
+
+
+class ValidationThrottler:
+    """Manages all validation throttling mechanisms"""
+
+    def __init__(
+        self,
+        queue_size: int = 32,
+        global_throttle_limit: int = 8192,
+        default_topic_throttle_limit: int = 1024,
+        worker_count: int | None = None,
+    ):
+        """
+        :param queue_size: capacity of the pending-request queue
+        :param global_throttle_limit: initial value of the global semaphore
+        :param default_topic_throttle_limit: per-topic limit used when a
+            validator is created without an explicit one
+        :param worker_count: number of worker tasks; falsy values mean 4
+        """
+        # 1. Queue-level throttling - bounded memory channel
+        self._validation_send, self._validation_receive = trio.open_memory_channel[
+            ValidationRequest
+        ](queue_size)
+
+        # 2. Global validation throttling - limits total concurrent async validations
+        self._global_throttle = trio.Semaphore(global_throttle_limit)
+
+        # 3. Per-topic throttling - each validator gets its own semaphore
+        self._default_topic_throttle_limit = default_topic_throttle_limit
+
+        # Worker management
+        # TODO: Find a better way to manage worker count
+        self._worker_count = worker_count or 4
+        self._running = False
+
+    async def start(self, nursery: trio.Nursery) -> None:
+        """Start the validation workers"""
+        self._running = True
+
+        # Start validation worker tasks
+        for i in range(self._worker_count):
+            nursery.start_soon(self._validation_worker, f"worker-{i}")
+
+    async def stop(self) -> None:
+        """Stop the validation system"""
+        self._running = False
+        await self._validation_send.aclose()
+
+    def create_topic_validator(
+        self,
+        topic: str,
+        validator: ValidatorFn,
+        is_async: bool,
+        timeout: float | None = None,
+        throttle_limit: int | None = None,
+    ) -> TopicValidator:
+        """
+        Create a new topic validator with its own throttle
+
+        A falsy ``throttle_limit`` (``None`` or ``0``) selects the default limit.
+        """
+        limit = throttle_limit or self._default_topic_throttle_limit
+        throttle_sem = trio.Semaphore(limit)
+
+        return TopicValidator(
+            topic=topic,
+            validator=validator,
+            is_async=is_async,
+            timeout=timeout,
+            throttle_semaphore=throttle_sem,
+        )
+
+    async def submit_validation(
+        self,
+        validators: list[TopicValidator],
+        msg_forwarder: Any,
+        msg: Any,
+        result_callback: Callable[[ValidationResult, Exception | None], None],
+    ) -> bool:
+        """
+        Submit a message for validation.
+
+        ``result_callback`` is invoked before returning whenever the request
+        is not queued.
+
+        :return: True if queued successfully, False if the request was refused
+            (system not running, queue full, or queue closed)
+        """
+        if not self._running:
+            result_callback(
+                ValidationResult.REJECT, Exception("Validation system not running")
+            )
+            return False
+
+        request = ValidationRequest(
+            validators=validators,
+            msg_forwarder=msg_forwarder,
+            msg=msg,
+            result_callback=result_callback,
+        )
+
+        try:
+            # This will raise trio.WouldBlock if queue is full
+            self._validation_send.send_nowait(request)
+            return True
+        except trio.WouldBlock:
+            # Queue-level throttling: drop the message
+            logger.debug(
+                "Validation queue full, dropping message from %s", msg_forwarder
+            )
+            result_callback(
+                ValidationResult.THROTTLED, Exception("Validation queue full")
+            )
+            return False
+        except (trio.BrokenResourceError, trio.ClosedResourceError) as error:
+            # No worker can pick the request up any more: refuse it instead of
+            # letting the channel error escape to the caller.
+            logger.debug("Validation queue closed, rejecting message: %s", error)
+            result_callback(ValidationResult.REJECT, error)
+            return False
+
+    async def _validation_worker(self, worker_id: str) -> None:
+        """
+        Worker that processes validation requests
+
+        Every dequeued request gets its ``result_callback`` invoked, since the
+        submitter may be waiting on it.
+        """
+        logger.debug("Validation worker %s started", worker_id)
+
+        async with self._validation_receive:
+            async for request in self._validation_receive:
+                if not self._running:
+                    # Shutting down: answer the request rather than dropping
+                    # it silently, then keep draining what is still queued.
+                    request.result_callback(
+                        ValidationResult.REJECT,
+                        Exception("Validation system not running"),
+                    )
+                    continue
+
+                try:
+                    # Process the validation request
+                    result = await self._validate_message(request)
+                    request.result_callback(result, None)
+                except Exception as e:
+                    logger.exception("Error in validation worker %s", worker_id)
+                    request.result_callback(ValidationResult.REJECT, e)
+
+        logger.debug("Validation worker %s stopped", worker_id)
+
+    async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
+        """Core validation logic with throttling"""
+        validators = request.validators
+        msg_forwarder = request.msg_forwarder
+        msg = request.msg
+
+        if not validators:
+            return ValidationResult.ACCEPT
+
+        # Separate sync and async validators
+        sync_validators = [v for v in validators if not v.is_async]
+        async_validators = [v for v in validators if v.is_async]
+
+        # Run synchronous validators first
+        for validator in sync_validators:
+            try:
+                # Apply per-topic throttling even for sync validators
+                if validator.throttle_semaphore:
+                    validator.throttle_semaphore.acquire_nowait()
+                    try:
+                        result = validator.validator(msg_forwarder, msg)
+                        if not result:
+                            return ValidationResult.REJECT
+                    finally:
+                        validator.throttle_semaphore.release()
+                else:
+                    result = validator.validator(msg_forwarder, msg)
+                    if not result:
+                        return ValidationResult.REJECT
+            except trio.WouldBlock:
+                # Per-topic throttling for sync validator
+                logger.debug("Sync validation throttled for topic %s", validator.topic)
+                return ValidationResult.THROTTLED
+            except Exception as e:
+                logger.exception(
+                    "Sync validator failed for topic %s: %s", validator.topic, e
+                )
+                return ValidationResult.REJECT
+
+        # Handle async validators with global + per-topic throttling
+        if async_validators:
+            return await self._validate_async_validators(
+                async_validators, msg_forwarder, msg
+            )
+
+        return ValidationResult.ACCEPT
+
+    async def _validate_async_validators(
+        self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
+    ) -> ValidationResult:
+        """
+        Handle async validators with proper throttling
+
+        REJECT from any validator wins; otherwise THROTTLED, then IGNORE, then
+        ACCEPT.
+        """
+        if len(validators) == 1:
+            # Fast path for single validator
+            return await self._validate_single_async_validator(
+                validators[0], msg_forwarder, msg
+            )
+
+        # Multiple async validators - run them concurrently
+        try:
+            # Try to acquire global throttle slot
+            self._global_throttle.acquire_nowait()
+        except trio.WouldBlock:
+            logger.debug(
+                "Global validation throttle exceeded, dropping message from %s",
+                msg_forwarder,
+            )
+            return ValidationResult.THROTTLED
+
+        results: dict[int, ValidationResult] = {}
+
+        async def run_validator(validator: TopicValidator, index: int) -> None:
+            """Run a single async validator and store the result"""
+            results[index] = await self._validate_single_async_validator(
+                validator, msg_forwarder, msg
+            )
+
+        try:
+            async with trio.open_nursery() as nursery:
+                # Start all validators concurrently
+                for i, validator in enumerate(validators):
+                    nursery.start_soon(run_validator, validator, i)
+
+            # The nursery block only exits once every validator has finished,
+            # so the results must be read after it, not inside it.
+            outcomes = set(results.values())
+        finally:
+            self._global_throttle.release()
+
+        # Process results - any reject or throttle causes overall failure
+        for outcome in (
+            ValidationResult.REJECT,
+            ValidationResult.THROTTLED,
+            ValidationResult.IGNORE,
+        ):
+            if outcome in outcomes:
+                return outcome
+        return ValidationResult.ACCEPT
+
+    async def _validate_single_async_validator(
+        self, validator: TopicValidator, msg_forwarder: Any, msg: Any
+    ) -> ValidationResult:
+        """
+        Validate with a single async validator
+
+        :return: THROTTLED if the topic semaphore has no free slot, IGNORE on
+            timeout, REJECT on a falsy result or an exception, else ACCEPT
+        """
+        # Apply per-topic throttling
+        if validator.throttle_semaphore:
+            try:
+                validator.throttle_semaphore.acquire_nowait()
+            except trio.WouldBlock:
+                logger.debug(
+                    "Per-topic validation throttled for topic %s", validator.topic
+                )
+                return ValidationResult.THROTTLED
+
+        try:
+            # Apply timeout if configured
+            if validator.timeout:
+                with trio.fail_after(validator.timeout):
+                    accepted = await self._call_validator(
+                        validator, msg_forwarder, msg
+                    )
+            else:
+                accepted = await self._call_validator(validator, msg_forwarder, msg)
+
+            return ValidationResult.ACCEPT if accepted else ValidationResult.REJECT
+
+        except trio.TooSlowError:
+            logger.debug("Validation timeout for topic %s", validator.topic)
+            return ValidationResult.IGNORE
+        except Exception as e:
+            logger.exception(
+                "Async validator failed for topic %s: %s", validator.topic, e
+            )
+            return ValidationResult.REJECT
+        finally:
+            if validator.throttle_semaphore:
+                validator.throttle_semaphore.release()
+
+    @staticmethod
+    async def _call_validator(
+        validator: TopicValidator, msg_forwarder: Any, msg: Any
+    ) -> bool:
+        """Call the validator, awaiting its result if it is awaitable."""
+        # Checking the returned value (rather than the function) also covers
+        # async callables that inspect.iscoroutinefunction does not recognise;
+        # an un-awaited coroutine object would otherwise be truthy.
+        outcome: bool | Awaitable[bool] = validator.validator(msg_forwarder, msg)
+        if inspect.isawaitable(outcome):
+            outcome = await outcome
+        return bool(outcome)