refactor: specify types for msg_forwarder and msg in ValidationRequest and related methods

Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
This commit is contained in:
2025-07-01 20:46:58 +05:30
parent 8d1e5fffd2
commit df23e3d899
2 changed files with 20 additions and 24 deletions

View File

@ -233,7 +233,7 @@ class Pubsub(Service, IPubsub):
await self.validation_throttler.start(nursery)
# Keep nursery alive until service stops
while self.manager.is_running:
await trio.sleep(1)
await self.manager.wait_finished()
@property
def my_id(self) -> ID:

View File

@ -1,20 +1,23 @@
from collections.abc import (
Awaitable,
Callable,
)
from dataclasses import dataclass
from enum import Enum
import inspect
import logging
from typing import (
Any,
NamedTuple,
cast,
)
import trio
from libp2p.custom_types import (
ValidatorFn,
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
from libp2p.peer.id import (
ID,
)
from .pb import (
rpc_pb2,
)
logger = logging.getLogger("libp2p.pubsub.validation")
@ -32,9 +35,8 @@ class ValidationRequest:
"""Request for message validation"""
validators: list["TopicValidator"]
# TODO: Use a more specific type for msg_forwarder
msg_forwarder: Any # peer ID
msg: Any # message object
msg_forwarder: ID # peer ID
msg: rpc_pb2.Message # message object
result_callback: Callable[[ValidationResult, Exception | None], None]
@ -109,8 +111,8 @@ class ValidationThrottler:
async def submit_validation(
self,
validators: list[TopicValidator],
msg_forwarder: Any,
msg: Any,
msg_forwarder: ID,
msg: rpc_pb2.Message,
result_callback: Callable[[ValidationResult, Exception | None], None],
) -> bool:
"""
@ -211,7 +213,7 @@ class ValidationThrottler:
return ValidationResult.ACCEPT
async def _validate_async_validators(
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Handle async validators with proper throttling"""
if len(validators) == 1:
@ -268,7 +270,7 @@ class ValidationThrottler:
return ValidationResult.IGNORE
async def _validate_single_async_validator(
self, validator: TopicValidator, msg_forwarder: Any, msg: Any
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
) -> ValidationResult:
"""Validate with a single async validator"""
# Apply per-topic throttling
@ -286,20 +288,14 @@ class ValidationThrottler:
try:
# Apply timeout if configured
result: bool | Awaitable[bool]
result: bool
if validator.timeout:
with trio.fail_after(validator.timeout):
func = validator.validator
if inspect.iscoroutinefunction(func):
result = await func(msg_forwarder, msg)
else:
result = func(msg_forwarder, msg)
else:
func = validator.validator
if inspect.iscoroutinefunction(func):
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
else:
result = func(msg_forwarder, msg)
else:
func = cast(AsyncValidatorFn, validator.validator)
result = await func(msg_forwarder, msg)
return ValidationResult.ACCEPT if result else ValidationResult.REJECT