mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
refactor: specify types for msg_forwarder and msg in ValidationRequest and related methods
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
This commit is contained in:
@ -233,7 +233,7 @@ class Pubsub(Service, IPubsub):
|
||||
await self.validation_throttler.start(nursery)
|
||||
# Keep nursery alive until service stops
|
||||
while self.manager.is_running:
|
||||
await trio.sleep(1)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@property
|
||||
def my_id(self) -> ID:
|
||||
|
||||
@ -1,20 +1,23 @@
|
||||
from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
NamedTuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import (
|
||||
ValidatorFn,
|
||||
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub.validation")
|
||||
@ -32,9 +35,8 @@ class ValidationRequest:
|
||||
"""Request for message validation"""
|
||||
|
||||
validators: list["TopicValidator"]
|
||||
# TODO: Use a more specific type for msg_forwarder
|
||||
msg_forwarder: Any # peer ID
|
||||
msg: Any # message object
|
||||
msg_forwarder: ID # peer ID
|
||||
msg: rpc_pb2.Message # message object
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None]
|
||||
|
||||
|
||||
@ -109,8 +111,8 @@ class ValidationThrottler:
|
||||
async def submit_validation(
|
||||
self,
|
||||
validators: list[TopicValidator],
|
||||
msg_forwarder: Any,
|
||||
msg: Any,
|
||||
msg_forwarder: ID,
|
||||
msg: rpc_pb2.Message,
|
||||
result_callback: Callable[[ValidationResult, Exception | None], None],
|
||||
) -> bool:
|
||||
"""
|
||||
@ -211,7 +213,7 @@ class ValidationThrottler:
|
||||
return ValidationResult.ACCEPT
|
||||
|
||||
async def _validate_async_validators(
|
||||
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
|
||||
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Handle async validators with proper throttling"""
|
||||
if len(validators) == 1:
|
||||
@ -268,7 +270,7 @@ class ValidationThrottler:
|
||||
return ValidationResult.IGNORE
|
||||
|
||||
async def _validate_single_async_validator(
|
||||
self, validator: TopicValidator, msg_forwarder: Any, msg: Any
|
||||
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
|
||||
) -> ValidationResult:
|
||||
"""Validate with a single async validator"""
|
||||
# Apply per-topic throttling
|
||||
@ -286,20 +288,14 @@ class ValidationThrottler:
|
||||
|
||||
try:
|
||||
# Apply timeout if configured
|
||||
result: bool | Awaitable[bool]
|
||||
result: bool
|
||||
if validator.timeout:
|
||||
with trio.fail_after(validator.timeout):
|
||||
func = validator.validator
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(msg_forwarder, msg)
|
||||
else:
|
||||
result = func(msg_forwarder, msg)
|
||||
else:
|
||||
func = validator.validator
|
||||
if inspect.iscoroutinefunction(func):
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
else:
|
||||
result = func(msg_forwarder, msg)
|
||||
else:
|
||||
func = cast(AsyncValidatorFn, validator.validator)
|
||||
result = await func(msg_forwarder, msg)
|
||||
|
||||
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
||||
|
||||
|
||||
Reference in New Issue
Block a user