mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Compare commits
12 Commits
80b58a2ae0
...
async-vali
| Author | SHA1 | Date | |
|---|---|---|---|
| 399a52bc54 | |||
| 50a1228d27 | |||
| 8655ed38be | |||
| 3410f65b0a | |||
| f9cafe1e50 | |||
| 741dd3993b | |||
| 8681286f73 | |||
| ae26d59a43 | |||
| d0ef290a2a | |||
| df23e3d899 | |||
| 8d1e5fffd2 | |||
| 3431efe97f |
@ -11,10 +11,6 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import (
|
|
||||||
NamedTuple,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
import base58
|
import base58
|
||||||
import trio
|
import trio
|
||||||
@ -30,8 +26,6 @@ from libp2p.crypto.keys import (
|
|||||||
PrivateKey,
|
PrivateKey,
|
||||||
)
|
)
|
||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import (
|
||||||
AsyncValidatorFn,
|
|
||||||
SyncValidatorFn,
|
|
||||||
TProtocol,
|
TProtocol,
|
||||||
ValidatorFn,
|
ValidatorFn,
|
||||||
)
|
)
|
||||||
@ -77,6 +71,11 @@ from .pubsub_notifee import (
|
|||||||
from .subscription import (
|
from .subscription import (
|
||||||
TrioSubscriptionAPI,
|
TrioSubscriptionAPI,
|
||||||
)
|
)
|
||||||
|
from .validation_throttler import (
|
||||||
|
TopicValidator,
|
||||||
|
ValidationResult,
|
||||||
|
ValidationThrottler,
|
||||||
|
)
|
||||||
from .validators import (
|
from .validators import (
|
||||||
PUBSUB_SIGNING_PREFIX,
|
PUBSUB_SIGNING_PREFIX,
|
||||||
signature_validator,
|
signature_validator,
|
||||||
@ -97,11 +96,6 @@ def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
|
|||||||
return base64.b64encode(hashlib.sha256(msg.data).digest())
|
return base64.b64encode(hashlib.sha256(msg.data).digest())
|
||||||
|
|
||||||
|
|
||||||
class TopicValidator(NamedTuple):
|
|
||||||
validator: ValidatorFn
|
|
||||||
is_async: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Pubsub(Service, IPubsub):
|
class Pubsub(Service, IPubsub):
|
||||||
host: IHost
|
host: IHost
|
||||||
|
|
||||||
@ -143,6 +137,11 @@ class Pubsub(Service, IPubsub):
|
|||||||
msg_id_constructor: Callable[
|
msg_id_constructor: Callable[
|
||||||
[rpc_pb2.Message], bytes
|
[rpc_pb2.Message], bytes
|
||||||
] = get_peer_and_seqno_msg_id,
|
] = get_peer_and_seqno_msg_id,
|
||||||
|
# TODO: these values have been copied from Go, but try to tune these dynamically
|
||||||
|
validation_queue_size: int = 32,
|
||||||
|
global_throttle_limit: int = 8192,
|
||||||
|
default_topic_throttle_limit: int = 1024,
|
||||||
|
validation_worker_count: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Construct a new Pubsub object, which is responsible for handling all
|
Construct a new Pubsub object, which is responsible for handling all
|
||||||
@ -203,7 +202,15 @@ class Pubsub(Service, IPubsub):
|
|||||||
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
||||||
self.peers = {}
|
self.peers = {}
|
||||||
|
|
||||||
# Map of topic to topic validator
|
# Validation Throttler
|
||||||
|
self.validation_throttler = ValidationThrottler(
|
||||||
|
queue_size=validation_queue_size,
|
||||||
|
global_throttle_limit=global_throttle_limit,
|
||||||
|
default_topic_throttle_limit=default_topic_throttle_limit,
|
||||||
|
worker_count=validation_worker_count or 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep a mapping of topic -> TopicValidator for easier lookup
|
||||||
self.topic_validators = {}
|
self.topic_validators = {}
|
||||||
|
|
||||||
self.counter = int(time.time())
|
self.counter = int(time.time())
|
||||||
@ -215,10 +222,19 @@ class Pubsub(Service, IPubsub):
|
|||||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
|
self.manager.run_daemon_task(self._start_validation_throttler)
|
||||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||||
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||||
await self.manager.wait_finished()
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
|
async def _start_validation_throttler(self) -> None:
|
||||||
|
"""Start validation throttler in current nursery context"""
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
await self.validation_throttler.start(nursery)
|
||||||
|
# Keep nursery alive until service stops
|
||||||
|
while self.manager.is_running:
|
||||||
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def my_id(self) -> ID:
|
def my_id(self) -> ID:
|
||||||
return self.host.get_id()
|
return self.host.get_id()
|
||||||
@ -298,7 +314,12 @@ class Pubsub(Service, IPubsub):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def set_topic_validator(
|
def set_topic_validator(
|
||||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
self,
|
||||||
|
topic: str,
|
||||||
|
validator: ValidatorFn,
|
||||||
|
is_async_validator: bool,
|
||||||
|
timeout: float | None = None,
|
||||||
|
throttle_limit: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Register a validator under the given topic. One topic can only have one
|
Register a validator under the given topic. One topic can only have one
|
||||||
@ -307,8 +328,18 @@ class Pubsub(Service, IPubsub):
|
|||||||
:param topic: the topic to register validator under
|
:param topic: the topic to register validator under
|
||||||
:param validator: the validator used to validate messages published to the topic
|
:param validator: the validator used to validate messages published to the topic
|
||||||
:param is_async_validator: indicate if the validator is an asynchronous validator
|
:param is_async_validator: indicate if the validator is an asynchronous validator
|
||||||
|
:param timeout: optional timeout for the validator
|
||||||
|
:param throttle_limit: optional throttle limit for the validator
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
|
# Create throttled topic validator
|
||||||
|
topic_validator = self.validation_throttler.create_topic_validator(
|
||||||
|
topic=topic,
|
||||||
|
validator=validator,
|
||||||
|
is_async=is_async_validator,
|
||||||
|
timeout=timeout,
|
||||||
|
throttle_limit=throttle_limit,
|
||||||
|
)
|
||||||
|
self.topic_validators[topic] = topic_validator
|
||||||
|
|
||||||
def remove_topic_validator(self, topic: str) -> None:
|
def remove_topic_validator(self, topic: str) -> None:
|
||||||
"""
|
"""
|
||||||
@ -318,17 +349,18 @@ class Pubsub(Service, IPubsub):
|
|||||||
"""
|
"""
|
||||||
self.topic_validators.pop(topic, None)
|
self.topic_validators.pop(topic, None)
|
||||||
|
|
||||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> tuple[TopicValidator, ...]:
|
def get_msg_validators(self, msg: rpc_pb2.Message) -> list[TopicValidator]:
|
||||||
"""
|
"""
|
||||||
Get all validators corresponding to the topics in the message.
|
Get all validators corresponding to the topics in the message.
|
||||||
|
|
||||||
:param msg: the message published to the topic
|
:param msg: the message published to the topic
|
||||||
|
:return: list of topic validators for the message's topics
|
||||||
"""
|
"""
|
||||||
return tuple(
|
return [
|
||||||
self.topic_validators[topic]
|
self.topic_validators[topic]
|
||||||
for topic in msg.topicIDs
|
for topic in msg.topicIDs
|
||||||
if topic in self.topic_validators
|
if topic in self.topic_validators
|
||||||
)
|
]
|
||||||
|
|
||||||
def add_to_blacklist(self, peer_id: ID) -> None:
|
def add_to_blacklist(self, peer_id: ID) -> None:
|
||||||
"""
|
"""
|
||||||
@ -664,38 +696,56 @@ class Pubsub(Service, IPubsub):
|
|||||||
:param msg_forwarder: the peer who forward us the message.
|
:param msg_forwarder: the peer who forward us the message.
|
||||||
:param msg: the message.
|
:param msg: the message.
|
||||||
"""
|
"""
|
||||||
sync_topic_validators: list[SyncValidatorFn] = []
|
# Get applicable validators for this message
|
||||||
async_topic_validators: list[AsyncValidatorFn] = []
|
validators = self.get_msg_validators(msg)
|
||||||
for topic_validator in self.get_msg_validators(msg):
|
|
||||||
if topic_validator.is_async:
|
if not validators:
|
||||||
async_topic_validators.append(
|
# No validators, accept immediately
|
||||||
cast(AsyncValidatorFn, topic_validator.validator)
|
return
|
||||||
)
|
|
||||||
else:
|
# Use trio.Event for async coordination
|
||||||
sync_topic_validators.append(
|
validation_event = trio.Event()
|
||||||
cast(SyncValidatorFn, topic_validator.validator)
|
result_container: dict[str, ValidationResult | None | Exception] = {
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_validation_result(
|
||||||
|
result: ValidationResult, error: Exception | None
|
||||||
|
) -> None:
|
||||||
|
result_container["result"] = result
|
||||||
|
result_container["error"] = error
|
||||||
|
validation_event.set()
|
||||||
|
|
||||||
|
# Submit for throttled validation
|
||||||
|
success = await self.validation_throttler.submit_validation(
|
||||||
|
validators=validators,
|
||||||
|
msg_forwarder=msg_forwarder,
|
||||||
|
msg=msg,
|
||||||
|
result_callback=handle_validation_result,
|
||||||
)
|
)
|
||||||
|
|
||||||
for validator in sync_topic_validators:
|
if not success:
|
||||||
if not validator(msg_forwarder, msg):
|
# Validation was throttled at queue level
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError("Validation throttled at queue level")
|
||||||
|
|
||||||
# TODO: Implement throttle on async validators
|
# Wait for validation result
|
||||||
|
await validation_event.wait()
|
||||||
|
|
||||||
if len(async_topic_validators) > 0:
|
result = result_container["result"]
|
||||||
# Appends to lists are thread safe in CPython
|
error = result_container["error"]
|
||||||
results = []
|
|
||||||
|
|
||||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
if error:
|
||||||
result = await func(msg_forwarder, msg)
|
raise ValidationError(f"Validation error: {error}")
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
if result == ValidationResult.REJECT:
|
||||||
for async_validator in async_topic_validators:
|
raise ValidationError("Message validation rejected")
|
||||||
nursery.start_soon(run_async_validator, async_validator)
|
elif result == ValidationResult.THROTTLED:
|
||||||
|
raise ValidationError("Message validation throttled")
|
||||||
if not all(results):
|
elif result == ValidationResult.IGNORE:
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
# Treat IGNORE as rejection for now, or you could silently drop
|
||||||
|
raise ValidationError("Message validation ignored")
|
||||||
|
# ACCEPT case - just return normally
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
314
libp2p/pubsub/validation_throttler.py
Normal file
314
libp2p/pubsub/validation_throttler.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
from collections.abc import (
|
||||||
|
Callable,
|
||||||
|
)
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
NamedTuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
|
||||||
|
from libp2p.peer.id import (
|
||||||
|
ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .pb import (
|
||||||
|
rpc_pb2,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.pubsub.validation")
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationResult(Enum):
|
||||||
|
ACCEPT = "accept"
|
||||||
|
REJECT = "reject"
|
||||||
|
IGNORE = "ignore"
|
||||||
|
THROTTLED = "throttled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationRequest:
|
||||||
|
"""Request for message validation"""
|
||||||
|
|
||||||
|
validators: list["TopicValidator"]
|
||||||
|
msg_forwarder: ID # peer ID
|
||||||
|
msg: rpc_pb2.Message # message object
|
||||||
|
result_callback: Callable[[ValidationResult, Exception | None], None]
|
||||||
|
|
||||||
|
|
||||||
|
class TopicValidator(NamedTuple):
|
||||||
|
topic: str
|
||||||
|
validator: ValidatorFn
|
||||||
|
is_async: bool
|
||||||
|
timeout: float | None = None
|
||||||
|
# Per-topic throttle semaphore
|
||||||
|
throttle_semaphore: trio.Semaphore | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationThrottler:
|
||||||
|
"""Manages all validation throttling mechanisms"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
queue_size: int = 32,
|
||||||
|
global_throttle_limit: int = 8192,
|
||||||
|
default_topic_throttle_limit: int = 1024,
|
||||||
|
worker_count: int | None = None,
|
||||||
|
):
|
||||||
|
# 1. Queue-level throttling - bounded memory channel
|
||||||
|
self._validation_send, self._validation_receive = trio.open_memory_channel[
|
||||||
|
ValidationRequest
|
||||||
|
](queue_size)
|
||||||
|
|
||||||
|
# 2. Global validation throttling - limits total concurrent async validations
|
||||||
|
self._global_throttle = trio.Semaphore(global_throttle_limit)
|
||||||
|
|
||||||
|
# 3. Per-topic throttling - each validator gets its own semaphore
|
||||||
|
self._default_topic_throttle_limit = default_topic_throttle_limit
|
||||||
|
|
||||||
|
# Worker management
|
||||||
|
# TODO: Find a better way to manage worker count
|
||||||
|
self._worker_count = worker_count or 4
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
async def start(self, nursery: trio.Nursery) -> None:
|
||||||
|
"""Start the validation workers"""
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
# Start validation worker tasks
|
||||||
|
for i in range(self._worker_count):
|
||||||
|
nursery.start_soon(self._validation_worker, f"worker-{i}")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the validation system"""
|
||||||
|
self._running = False
|
||||||
|
await self._validation_send.aclose()
|
||||||
|
|
||||||
|
def create_topic_validator(
|
||||||
|
self,
|
||||||
|
topic: str,
|
||||||
|
validator: ValidatorFn,
|
||||||
|
is_async: bool,
|
||||||
|
timeout: float | None = None,
|
||||||
|
throttle_limit: int | None = None,
|
||||||
|
) -> TopicValidator:
|
||||||
|
"""Create a new topic validator with its own throttle"""
|
||||||
|
limit = throttle_limit or self._default_topic_throttle_limit
|
||||||
|
throttle_sem = trio.Semaphore(limit)
|
||||||
|
|
||||||
|
return TopicValidator(
|
||||||
|
topic=topic,
|
||||||
|
validator=validator,
|
||||||
|
is_async=is_async,
|
||||||
|
timeout=timeout,
|
||||||
|
throttle_semaphore=throttle_sem,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def submit_validation(
|
||||||
|
self,
|
||||||
|
validators: list[TopicValidator],
|
||||||
|
msg_forwarder: ID,
|
||||||
|
msg: rpc_pb2.Message,
|
||||||
|
result_callback: Callable[[ValidationResult, Exception | None], None],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Submit a message for validation.
|
||||||
|
Returns True if queued successfully, False if queue is full (throttled).
|
||||||
|
"""
|
||||||
|
if not self._running:
|
||||||
|
result_callback(
|
||||||
|
ValidationResult.REJECT, Exception("Validation system not running")
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
request = ValidationRequest(
|
||||||
|
validators=validators,
|
||||||
|
msg_forwarder=msg_forwarder,
|
||||||
|
msg=msg,
|
||||||
|
result_callback=result_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# This will raise trio.WouldBlock if queue is full
|
||||||
|
self._validation_send.send_nowait(request)
|
||||||
|
return True
|
||||||
|
except trio.WouldBlock:
|
||||||
|
# Queue-level throttling: drop the message
|
||||||
|
logger.debug(
|
||||||
|
"Validation queue full, dropping message from %s", msg_forwarder
|
||||||
|
)
|
||||||
|
result_callback(
|
||||||
|
ValidationResult.THROTTLED, Exception("Validation queue full")
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _validation_worker(self, worker_id: str) -> None:
|
||||||
|
"""Worker that processes validation requests"""
|
||||||
|
logger.debug("Validation worker %s started", worker_id)
|
||||||
|
|
||||||
|
async with self._validation_receive:
|
||||||
|
async for request in self._validation_receive:
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process the validation request
|
||||||
|
result = await self._validate_message(request)
|
||||||
|
request.result_callback(result, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in validation worker %s", worker_id)
|
||||||
|
request.result_callback(ValidationResult.REJECT, e)
|
||||||
|
|
||||||
|
logger.debug("Validation worker %s stopped", worker_id)
|
||||||
|
|
||||||
|
async def _validate_message(self, request: ValidationRequest) -> ValidationResult:
|
||||||
|
"""Core validation logic with throttling"""
|
||||||
|
validators = request.validators
|
||||||
|
msg_forwarder = request.msg_forwarder
|
||||||
|
msg = request.msg
|
||||||
|
|
||||||
|
if not validators:
|
||||||
|
return ValidationResult.ACCEPT
|
||||||
|
|
||||||
|
# Separate sync and async validators
|
||||||
|
sync_validators = [v for v in validators if not v.is_async]
|
||||||
|
async_validators = [v for v in validators if v.is_async]
|
||||||
|
|
||||||
|
# Run synchronous validators first
|
||||||
|
for validator in sync_validators:
|
||||||
|
try:
|
||||||
|
# Apply per-topic throttling even for sync validators
|
||||||
|
if validator.throttle_semaphore:
|
||||||
|
validator.throttle_semaphore.acquire_nowait()
|
||||||
|
try:
|
||||||
|
result = validator.validator(msg_forwarder, msg)
|
||||||
|
if not result:
|
||||||
|
return ValidationResult.REJECT
|
||||||
|
finally:
|
||||||
|
validator.throttle_semaphore.release()
|
||||||
|
else:
|
||||||
|
result = validator.validator(msg_forwarder, msg)
|
||||||
|
if not result:
|
||||||
|
return ValidationResult.REJECT
|
||||||
|
except trio.WouldBlock:
|
||||||
|
# Per-topic throttling for sync validator
|
||||||
|
logger.debug("Sync validation throttled for topic %s", validator.topic)
|
||||||
|
return ValidationResult.THROTTLED
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Sync validator failed for topic %s: %s", validator.topic, e
|
||||||
|
)
|
||||||
|
return ValidationResult.REJECT
|
||||||
|
|
||||||
|
# Handle async validators with global + per-topic throttling
|
||||||
|
if async_validators:
|
||||||
|
return await self._validate_async_validators(
|
||||||
|
async_validators, msg_forwarder, msg
|
||||||
|
)
|
||||||
|
|
||||||
|
return ValidationResult.ACCEPT
|
||||||
|
|
||||||
|
async def _validate_async_validators(
|
||||||
|
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
|
||||||
|
) -> ValidationResult:
|
||||||
|
"""Handle async validators with proper throttling"""
|
||||||
|
if len(validators) == 1:
|
||||||
|
# Fast path for single validator
|
||||||
|
return await self._validate_single_async_validator(
|
||||||
|
validators[0], msg_forwarder, msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Multiple async validators - run them concurrently
|
||||||
|
try:
|
||||||
|
# Try to acquire global throttle slot
|
||||||
|
self._global_throttle.acquire_nowait()
|
||||||
|
except trio.WouldBlock:
|
||||||
|
logger.debug(
|
||||||
|
"Global validation throttle exceeded, dropping message from %s",
|
||||||
|
msg_forwarder,
|
||||||
|
)
|
||||||
|
return ValidationResult.THROTTLED
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
async def run_validator(validator: TopicValidator, index: int) -> None:
|
||||||
|
"""Run a single async validator and store the result"""
|
||||||
|
nonlocal results
|
||||||
|
result = await self._validate_single_async_validator(
|
||||||
|
validator, msg_forwarder, msg
|
||||||
|
)
|
||||||
|
results[index] = result
|
||||||
|
|
||||||
|
# Start all validators concurrently
|
||||||
|
for i, validator in enumerate(validators):
|
||||||
|
nursery.start_soon(run_validator, validator, i)
|
||||||
|
|
||||||
|
# Process results - any reject or throttle causes overall failure
|
||||||
|
final_result = ValidationResult.ACCEPT
|
||||||
|
for result in results.values():
|
||||||
|
if result == ValidationResult.REJECT:
|
||||||
|
return ValidationResult.REJECT
|
||||||
|
elif result == ValidationResult.THROTTLED:
|
||||||
|
final_result = ValidationResult.THROTTLED
|
||||||
|
elif (
|
||||||
|
result == ValidationResult.IGNORE
|
||||||
|
and final_result == ValidationResult.ACCEPT
|
||||||
|
):
|
||||||
|
final_result = ValidationResult.IGNORE
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._global_throttle.release()
|
||||||
|
|
||||||
|
return ValidationResult.IGNORE
|
||||||
|
|
||||||
|
async def _validate_single_async_validator(
|
||||||
|
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
|
||||||
|
) -> ValidationResult:
|
||||||
|
"""Validate with a single async validator"""
|
||||||
|
# Apply per-topic throttling
|
||||||
|
if validator.throttle_semaphore:
|
||||||
|
try:
|
||||||
|
validator.throttle_semaphore.acquire_nowait()
|
||||||
|
except trio.WouldBlock:
|
||||||
|
logger.debug(
|
||||||
|
"Per-topic validation throttled for topic %s", validator.topic
|
||||||
|
)
|
||||||
|
return ValidationResult.THROTTLED
|
||||||
|
else:
|
||||||
|
# Fallback if no throttle semaphore configured
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply timeout if configured
|
||||||
|
result: bool
|
||||||
|
if validator.timeout:
|
||||||
|
with trio.fail_after(validator.timeout):
|
||||||
|
func = cast(AsyncValidatorFn, validator.validator)
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
else:
|
||||||
|
func = cast(AsyncValidatorFn, validator.validator)
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
|
||||||
|
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
||||||
|
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logger.debug("Validation timeout for topic %s", validator.topic)
|
||||||
|
return ValidationResult.IGNORE
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Async validator failed for topic %s: %s", validator.topic, e
|
||||||
|
)
|
||||||
|
return ValidationResult.REJECT
|
||||||
|
finally:
|
||||||
|
if validator.throttle_semaphore:
|
||||||
|
validator.throttle_semaphore.release()
|
||||||
|
|
||||||
|
return ValidationResult.IGNORE
|
||||||
Reference in New Issue
Block a user