mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
refactor: specify types for msg_forwarder and msg in ValidationRequest and related methods
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
This commit is contained in:
@ -233,7 +233,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
await self.validation_throttler.start(nursery)
|
await self.validation_throttler.start(nursery)
|
||||||
# Keep nursery alive until service stops
|
# Keep nursery alive until service stops
|
||||||
while self.manager.is_running:
|
while self.manager.is_running:
|
||||||
await trio.sleep(1)
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def my_id(self) -> ID:
|
def my_id(self) -> ID:
|
||||||
|
|||||||
@ -1,20 +1,23 @@
|
|||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
)
|
)
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
|
||||||
ValidatorFn,
|
from libp2p.peer.id import (
|
||||||
|
ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .pb import (
|
||||||
|
rpc_pb2,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.pubsub.validation")
|
logger = logging.getLogger("libp2p.pubsub.validation")
|
||||||
@ -32,9 +35,8 @@ class ValidationRequest:
|
|||||||
"""Request for message validation"""
|
"""Request for message validation"""
|
||||||
|
|
||||||
validators: list["TopicValidator"]
|
validators: list["TopicValidator"]
|
||||||
# TODO: Use a more specific type for msg_forwarder
|
msg_forwarder: ID # peer ID
|
||||||
msg_forwarder: Any # peer ID
|
msg: rpc_pb2.Message # message object
|
||||||
msg: Any # message object
|
|
||||||
result_callback: Callable[[ValidationResult, Exception | None], None]
|
result_callback: Callable[[ValidationResult, Exception | None], None]
|
||||||
|
|
||||||
|
|
||||||
@ -109,8 +111,8 @@ class ValidationThrottler:
|
|||||||
async def submit_validation(
|
async def submit_validation(
|
||||||
self,
|
self,
|
||||||
validators: list[TopicValidator],
|
validators: list[TopicValidator],
|
||||||
msg_forwarder: Any,
|
msg_forwarder: ID,
|
||||||
msg: Any,
|
msg: rpc_pb2.Message,
|
||||||
result_callback: Callable[[ValidationResult, Exception | None], None],
|
result_callback: Callable[[ValidationResult, Exception | None], None],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -211,7 +213,7 @@ class ValidationThrottler:
|
|||||||
return ValidationResult.ACCEPT
|
return ValidationResult.ACCEPT
|
||||||
|
|
||||||
async def _validate_async_validators(
|
async def _validate_async_validators(
|
||||||
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
|
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
|
||||||
) -> ValidationResult:
|
) -> ValidationResult:
|
||||||
"""Handle async validators with proper throttling"""
|
"""Handle async validators with proper throttling"""
|
||||||
if len(validators) == 1:
|
if len(validators) == 1:
|
||||||
@ -268,7 +270,7 @@ class ValidationThrottler:
|
|||||||
return ValidationResult.IGNORE
|
return ValidationResult.IGNORE
|
||||||
|
|
||||||
async def _validate_single_async_validator(
|
async def _validate_single_async_validator(
|
||||||
self, validator: TopicValidator, msg_forwarder: Any, msg: Any
|
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
|
||||||
) -> ValidationResult:
|
) -> ValidationResult:
|
||||||
"""Validate with a single async validator"""
|
"""Validate with a single async validator"""
|
||||||
# Apply per-topic throttling
|
# Apply per-topic throttling
|
||||||
@ -286,20 +288,14 @@ class ValidationThrottler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply timeout if configured
|
# Apply timeout if configured
|
||||||
result: bool | Awaitable[bool]
|
result: bool
|
||||||
if validator.timeout:
|
if validator.timeout:
|
||||||
with trio.fail_after(validator.timeout):
|
with trio.fail_after(validator.timeout):
|
||||||
func = validator.validator
|
func = cast(AsyncValidatorFn, validator.validator)
|
||||||
if inspect.iscoroutinefunction(func):
|
|
||||||
result = await func(msg_forwarder, msg)
|
|
||||||
else:
|
|
||||||
result = func(msg_forwarder, msg)
|
|
||||||
else:
|
|
||||||
func = validator.validator
|
|
||||||
if inspect.iscoroutinefunction(func):
|
|
||||||
result = await func(msg_forwarder, msg)
|
result = await func(msg_forwarder, msg)
|
||||||
else:
|
else:
|
||||||
result = func(msg_forwarder, msg)
|
func = cast(AsyncValidatorFn, validator.validator)
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
|
||||||
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user