diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 1614bedc..e00dc945 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -233,7 +233,7 @@ class Pubsub(Service, IPubsub): await self.validation_throttler.start(nursery) # Keep nursery alive until service stops while self.manager.is_running: - await trio.sleep(1) + await self.manager.wait_finished() @property def my_id(self) -> ID: diff --git a/libp2p/pubsub/validation_throttler.py b/libp2p/pubsub/validation_throttler.py index 38b4991c..cea4ce27 100644 --- a/libp2p/pubsub/validation_throttler.py +++ b/libp2p/pubsub/validation_throttler.py @@ -1,20 +1,23 @@ from collections.abc import ( - Awaitable, Callable, ) from dataclasses import dataclass from enum import Enum -import inspect import logging from typing import ( - Any, NamedTuple, + cast, ) import trio -from libp2p.custom_types import ( - ValidatorFn, +from libp2p.custom_types import AsyncValidatorFn, ValidatorFn +from libp2p.peer.id import ( + ID, +) + +from .pb import ( + rpc_pb2, ) logger = logging.getLogger("libp2p.pubsub.validation") @@ -32,9 +35,8 @@ class ValidationRequest: """Request for message validation""" validators: list["TopicValidator"] - # TODO: Use a more specific type for msg_forwarder - msg_forwarder: Any # peer ID - msg: Any # message object + msg_forwarder: ID # peer ID + msg: rpc_pb2.Message # message object result_callback: Callable[[ValidationResult, Exception | None], None] @@ -109,8 +111,8 @@ class ValidationThrottler: async def submit_validation( self, validators: list[TopicValidator], - msg_forwarder: Any, - msg: Any, + msg_forwarder: ID, + msg: rpc_pb2.Message, result_callback: Callable[[ValidationResult, Exception | None], None], ) -> bool: """ @@ -211,7 +213,7 @@ class ValidationThrottler: return ValidationResult.ACCEPT async def _validate_async_validators( - self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any + self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message ) -> ValidationResult: """Handle async validators with proper throttling""" if len(validators) == 1: @@ -268,7 +270,7 @@ class ValidationThrottler: return ValidationResult.IGNORE async def _validate_single_async_validator( - self, validator: TopicValidator, msg_forwarder: Any, msg: Any + self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message ) -> ValidationResult: """Validate with a single async validator""" # Apply per-topic throttling @@ -286,20 +288,14 @@ class ValidationThrottler: try: # Apply timeout if configured - result: bool | Awaitable[bool] + result: bool if validator.timeout: with trio.fail_after(validator.timeout): - func = validator.validator - if inspect.iscoroutinefunction(func): - result = await func(msg_forwarder, msg) - else: - result = func(msg_forwarder, msg) - else: - func = validator.validator - if inspect.iscoroutinefunction(func): + func = cast(AsyncValidatorFn, validator.validator) result = await func(msg_forwarder, msg) - else: - result = func(msg_forwarder, msg) + else: + func = cast(AsyncValidatorFn, validator.validator) + result = await func(msg_forwarder, msg) return ValidationResult.ACCEPT if result else ValidationResult.REJECT