refactor: specify types for msg_forwarder and msg in ValidationRequest and related methods

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
This commit is contained in:
2025-07-01 20:46:58 +05:30
parent 8d1e5fffd2
commit df23e3d899
2 changed files with 20 additions and 24 deletions

View File

@ -233,7 +233,7 @@ class Pubsub(Service, IPubsub):
await self.validation_throttler.start(nursery) await self.validation_throttler.start(nursery)
# Keep nursery alive until service stops # Keep nursery alive until service stops
while self.manager.is_running: while self.manager.is_running:
await trio.sleep(1) await self.manager.wait_finished()
@property @property
def my_id(self) -> ID: def my_id(self) -> ID:

View File

@ -1,20 +1,23 @@
from collections.abc import ( from collections.abc import (
Awaitable,
Callable, Callable,
) )
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
import inspect
import logging import logging
from typing import ( from typing import (
Any,
NamedTuple, NamedTuple,
cast,
) )
import trio import trio
from libp2p.custom_types import ( from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
ValidatorFn, from libp2p.peer.id import (
ID,
)
from .pb import (
rpc_pb2,
) )
logger = logging.getLogger("libp2p.pubsub.validation") logger = logging.getLogger("libp2p.pubsub.validation")
@ -32,9 +35,8 @@ class ValidationRequest:
"""Request for message validation""" """Request for message validation"""
validators: list["TopicValidator"] validators: list["TopicValidator"]
# TODO: Use a more specific type for msg_forwarder msg_forwarder: ID # peer ID
msg_forwarder: Any # peer ID msg: rpc_pb2.Message # message object
msg: Any # message object
result_callback: Callable[[ValidationResult, Exception | None], None] result_callback: Callable[[ValidationResult, Exception | None], None]
@ -109,8 +111,8 @@ class ValidationThrottler:
async def submit_validation( async def submit_validation(
self, self,
validators: list[TopicValidator], validators: list[TopicValidator],
msg_forwarder: Any, msg_forwarder: ID,
msg: Any, msg: rpc_pb2.Message,
result_callback: Callable[[ValidationResult, Exception | None], None], result_callback: Callable[[ValidationResult, Exception | None], None],
) -> bool: ) -> bool:
""" """
@ -211,7 +213,7 @@ class ValidationThrottler:
return ValidationResult.ACCEPT return ValidationResult.ACCEPT
async def _validate_async_validators( async def _validate_async_validators(
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult: ) -> ValidationResult:
"""Handle async validators with proper throttling""" """Handle async validators with proper throttling"""
if len(validators) == 1: if len(validators) == 1:
@ -268,7 +270,7 @@ class ValidationThrottler:
return ValidationResult.IGNORE return ValidationResult.IGNORE
async def _validate_single_async_validator( async def _validate_single_async_validator(
self, validator: TopicValidator, msg_forwarder: Any, msg: Any self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult: ) -> ValidationResult:
"""Validate with a single async validator""" """Validate with a single async validator"""
# Apply per-topic throttling # Apply per-topic throttling
@ -286,20 +288,14 @@ class ValidationThrottler:
try: try:
# Apply timeout if configured # Apply timeout if configured
result: bool | Awaitable[bool] result: bool
if validator.timeout: if validator.timeout:
with trio.fail_after(validator.timeout): with trio.fail_after(validator.timeout):
func = validator.validator func = cast(AsyncValidatorFn, validator.validator)
if inspect.iscoroutinefunction(func):
result = await func(msg_forwarder, msg)
else:
result = func(msg_forwarder, msg)
else:
func = validator.validator
if inspect.iscoroutinefunction(func):
result = await func(msg_forwarder, msg) result = await func(msg_forwarder, msg)
else: else:
result = func(msg_forwarder, msg) func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
return ValidationResult.ACCEPT if result else ValidationResult.REJECT return ValidationResult.ACCEPT if result else ValidationResult.REJECT