mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-09 22:50:54 +00:00
Add clean-up logics into TrioSubscriptionAPI
Register an `unsubscribe_fn` when initializing the TrioSubscriptionAPI. `unsubscribe_fn` is called when subscription is unsubscribed.
This commit is contained in:
@ -24,7 +24,7 @@ class ISubscriptionAPI(
|
||||
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
|
||||
):
|
||||
@abstractmethod
|
||||
async def cancel(self) -> None:
|
||||
async def unsubscribe(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
@ -387,9 +388,14 @@ class Pubsub(Service, IPubsub):
|
||||
if topic_id in self.topic_ids:
|
||||
return self.subscribed_topics_receive[topic_id]
|
||||
|
||||
channels = trio.open_memory_channel[rpc_pb2.Message](math.inf)
|
||||
send_channel, receive_channel = channels
|
||||
subscription = TrioSubscriptionAPI(receive_channel)
|
||||
send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message](
|
||||
math.inf
|
||||
)
|
||||
|
||||
subscription = TrioSubscriptionAPI(
|
||||
receive_channel,
|
||||
unsubscribe_fn=functools.partial(self.unsubscribe, topic_id),
|
||||
)
|
||||
self.subscribed_topics_send[topic_id] = send_channel
|
||||
self.subscribed_topics_receive[topic_id] = subscription
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import trio
|
||||
|
||||
from .abc import ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .typing import UnsubscribeFn
|
||||
|
||||
|
||||
class BaseSubscriptionAPI(ISubscriptionAPI):
|
||||
@ -18,19 +19,25 @@ class BaseSubscriptionAPI(ISubscriptionAPI):
|
||||
exc_value: "Optional[BaseException]",
|
||||
traceback: "Optional[TracebackType]",
|
||||
) -> None:
|
||||
await self.cancel()
|
||||
await self.unsubscribe()
|
||||
|
||||
|
||||
class TrioSubscriptionAPI(BaseSubscriptionAPI):
|
||||
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
|
||||
unsubscribe_fn: UnsubscribeFn
|
||||
|
||||
def __init__(
|
||||
self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
|
||||
self,
|
||||
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]",
|
||||
unsubscribe_fn: UnsubscribeFn,
|
||||
) -> None:
|
||||
self.receive_channel = receive_channel
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
self.unsubscribe_fn = unsubscribe_fn # type: ignore
|
||||
|
||||
async def cancel(self) -> None:
|
||||
await self.receive_channel.aclose()
|
||||
async def unsubscribe(self) -> None:
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
await self.unsubscribe_fn() # type: ignore
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
|
||||
return self.receive_channel.__aiter__()
|
||||
|
||||
@ -7,3 +7,5 @@ from .pb import rpc_pb2
|
||||
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
|
||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||
|
||||
Reference in New Issue
Block a user