Files
py-libp2p/libp2p/tools/async_service/trio_service.py
2025-01-25 15:48:39 -07:00

448 lines
13 KiB
Python

# Originally copied from https://github.com/ethereum/async-service
from __future__ import (
annotations,
)
from collections.abc import (
AsyncIterator,
Awaitable,
Coroutine,
Sequence,
)
from contextlib import (
asynccontextmanager,
)
import functools
import sys
from typing import (
Any,
Callable,
Optional,
TypeVar,
cast,
)
if sys.version_info >= (3, 11):
from builtins import (
ExceptionGroup,
)
else:
from exceptiongroup import ExceptionGroup
import trio
import trio_typing
from ._utils import (
get_task_name,
)
from .abc import (
ManagerAPI,
ServiceAPI,
TaskAPI,
TaskWithChildrenAPI,
)
from .base import (
BaseChildServiceTask,
BaseFunctionTask,
BaseManager,
)
from .exceptions import (
DaemonTaskExit,
LifecycleError,
)
from .typing import (
EXC_INFO,
AsyncFn,
)
class FunctionTask(BaseFunctionTask):
_trio_task: trio.lowlevel.Task | None = None
def __init__(
self,
name: str,
daemon: bool,
parent: TaskWithChildrenAPI | None,
async_fn: AsyncFn,
async_fn_args: Sequence[Any],
) -> None:
super().__init__(name, daemon, parent, async_fn, async_fn_args)
# We use an event to manually track when the child task is "done".
# This is because trio has no API for awaiting completion of a task.
self._done = trio.Event()
# Each task gets its own `CancelScope` which is how we can manually
# control cancellation order of the task DAG
self._cancel_scope = trio.CancelScope()
#
# Trio specific API
#
@property
def has_trio_task(self) -> bool:
return self._trio_task is not None
@property
def trio_task(self) -> trio.lowlevel.Task:
if self._trio_task is None:
raise LifecycleError("Trio task not set yet")
return self._trio_task
@trio_task.setter
def trio_task(self, value: trio.lowlevel.Task) -> None:
if self._trio_task is not None:
raise LifecycleError(f"Task already set: {self._trio_task}")
self._trio_task = value
#
# Core Task API
#
async def run(self) -> None:
self.trio_task = trio.lowlevel.current_task()
try:
with self._cancel_scope:
await self._async_fn(*self._async_fn_args)
if self.daemon:
raise DaemonTaskExit(f"Daemon task {self} exited")
while self.children:
await tuple(self.children)[0].wait_done()
finally:
self._done.set()
if self.parent is not None:
self.parent.discard_child(self)
async def cancel(self) -> None:
for task in tuple(self.children):
await task.cancel()
self._cancel_scope.cancel()
await self.wait_done()
@property
def is_done(self) -> bool:
return self._done.is_set()
async def wait_done(self) -> None:
await self._done.wait()
class ChildServiceTask(BaseChildServiceTask):
def __init__(
self,
name: str,
daemon: bool,
parent: TaskWithChildrenAPI | None,
child_service: ServiceAPI,
) -> None:
super().__init__(name, daemon, parent)
self._child_service = child_service
self.child_manager = TrioManager(child_service)
async def cancel(self) -> None:
if self.child_manager.is_started:
await self.child_manager.stop()
class TrioManager(BaseManager):
# A nursery for sub tasks and services. This nursery is cancelled if the
# service is cancelled but allowed to exit normally if the service exits.
_task_nursery: trio_typing.Nursery
def __init__(self, service: ServiceAPI) -> None:
super().__init__(service)
# events
self._started = trio.Event()
self._cancelled = trio.Event()
self._finished = trio.Event()
# locks
self._run_lock = trio.Lock()
#
# System Tasks
#
async def _handle_cancelled(self) -> None:
self.logger.debug("%s: _handle_cancelled waiting for cancellation", self)
await self._cancelled.wait()
self.logger.debug("%s: _handle_cancelled triggering task cancellation", self)
# The `_root_tasks` changes size as each task completes itself
# and removes itself from the set. For this reason we iterate over a
# copy of the set.
for task in tuple(self._root_tasks):
await task.cancel()
# This finaly cancellation of the task nursery's cancel scope ensures
# that nothing is left behind and that the service will reliably exit.
self._task_nursery.cancel_scope.cancel()
@classmethod
async def run_service(cls, service: ServiceAPI) -> None:
manager = cls(service)
await manager.run()
async def run(self) -> None:
if self._run_lock.locked():
raise LifecycleError(
"Cannot run a service with the run lock already engaged. "
"Already started?"
)
elif self.is_started:
raise LifecycleError("Cannot run a service which is already started.")
try:
async with self._run_lock:
async with trio.open_nursery() as system_nursery:
system_nursery.start_soon(self._handle_cancelled)
try:
async with trio.open_nursery() as task_nursery:
self._task_nursery = task_nursery
self._started.set()
self.run_task(self._service.run, name="run")
# This is hack to get the task stats correct. We don't want
# to count the `Service.run` method as a task. This is still
# imperfect as it will still count as a completed task when
# it finishes.
self._total_task_count = 0
# ***BLOCKING HERE***
# The code flow will block here until the background tasks
# have completed or cancellation occurs.
except Exception:
# Exceptions from any tasks spawned by our service will be
# caught by trio and raised here, so we store them to report
# together with any others we have already captured.
self._errors.append(cast(EXC_INFO, sys.exc_info()))
finally:
system_nursery.cancel_scope.cancel()
finally:
# We need this inside a finally because a trio.Cancelled exception may be
# raised here and it wouldn't be swalled by the 'except Exception' above.
self._finished.set()
self.logger.debug("%s: finished", self)
# This is outside of the finally block above because we don't want to suppress
# trio.Cancelled or ExceptionGroup exceptions coming directly from trio.
if self.did_error:
raise ExceptionGroup(
"Encountered multiple Exceptions: ",
tuple(
exc_value.with_traceback(exc_tb)
for _, exc_value, exc_tb in self._errors
if isinstance(exc_value, Exception)
),
)
#
# Event API mirror
#
@property
def is_started(self) -> bool:
return self._started.is_set()
@property
def is_cancelled(self) -> bool:
return self._cancelled.is_set()
@property
def is_finished(self) -> bool:
return self._finished.is_set()
#
# Control API
#
def cancel(self) -> None:
if not self.is_started:
raise LifecycleError("Cannot cancel as service which was never started.")
elif not self.is_running:
return
else:
self._cancelled.set()
#
# Wait API
#
async def wait_started(self) -> None:
await self._started.wait()
async def wait_finished(self) -> None:
await self._finished.wait()
def _find_parent_task(
self, trio_task: trio.lowlevel.Task
) -> TaskWithChildrenAPI | None:
"""
Find the :class:`async_service.trio.FunctionTask` instance that corresponds to
the given :class:`trio.lowlevel.Task` instance.
"""
for task in FunctionTask.iterate_tasks(*self._root_tasks):
# Any task that has not had its `trio_task` set can be safely
# skipped as those are still in the process of starting up which
# means that they cannot be the parent task since they will not
# have had a chance to schedule child tasks.
if not task.has_trio_task:
continue
if trio_task is task.trio_task:
return task
else:
# In the case that no tasks match we assume this is a new `root`
# task and return `None` as the parent.
return None
def _schedule_task(self, task: TaskAPI) -> None:
self._task_nursery.start_soon(self._run_and_manage_task, task, name=str(task))
def run_task(
self,
async_fn: Callable[..., Awaitable[Any]],
*args: Any,
daemon: bool = False,
name: str = None,
) -> None:
task = FunctionTask(
name=get_task_name(async_fn, name),
daemon=daemon,
parent=self._find_parent_task(trio.lowlevel.current_task()),
async_fn=async_fn,
async_fn_args=args,
)
self._common_run_task(task)
def run_child_service(
self, service: ServiceAPI, daemon: bool = False, name: str = None
) -> ManagerAPI:
task = ChildServiceTask(
name=get_task_name(service, name),
daemon=daemon,
parent=self._find_parent_task(trio.lowlevel.current_task()),
child_service=service,
)
self._common_run_task(task)
return task.child_manager
TFunc = TypeVar("TFunc", bound=Callable[..., Coroutine[Any, Any, Any]])
_ChannelPayload = tuple[Optional[Any], Optional[BaseException]]
async def _wait_finished(
service: ServiceAPI,
api_func: Callable[..., Any],
channel: trio.abc.SendChannel[_ChannelPayload],
) -> None:
manager = service.get_manager()
if manager.is_finished:
await channel.send(
(
None,
LifecycleError(
f"Cannot access external API {api_func}. "
f"Service {service} is not running: "
),
)
)
return
await manager.wait_finished()
await channel.send(
(
None,
LifecycleError(
f"Cannot access external API {api_func}. "
f"Service {service} is not running: "
),
)
)
async def _wait_api_fn(
self: ServiceAPI,
api_fn: Callable[..., Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
channel: trio.abc.SendChannel[_ChannelPayload],
) -> None:
try:
result = await api_fn(self, *args, **kwargs)
except Exception:
_, exc_value, exc_tb = sys.exc_info()
if exc_value is None or exc_tb is None:
raise Exception(
"This should be unreachable but acts as a type guard for mypy"
)
await channel.send((None, exc_value.with_traceback(exc_tb)))
else:
await channel.send((result, None))
def external_api(func: TFunc) -> TFunc:
@functools.wraps(func)
async def inner(self: ServiceAPI, *args: Any, **kwargs: Any) -> Any:
if not hasattr(self, "manager"):
raise LifecycleError(
f"Cannot access external API {func}. Service {self} has not been run."
)
manager = self.get_manager()
if not manager.is_running:
raise LifecycleError(
f"Cannot access external API {func}. Service {self} is not running: "
)
channels: tuple[
trio.abc.SendChannel[_ChannelPayload],
trio.abc.ReceiveChannel[_ChannelPayload],
] = trio.open_memory_channel(0)
send_channel, receive_channel = channels
async with trio.open_nursery() as nursery:
# mypy's type hints for start_soon break with this invocation.
nursery.start_soon(
_wait_api_fn, self, func, args, kwargs, send_channel # type: ignore
)
nursery.start_soon(_wait_finished, self, func, send_channel)
result, err = await receive_channel.receive()
nursery.cancel_scope.cancel()
if err is None:
return result
else:
raise err
return cast(TFunc, inner)
@asynccontextmanager
async def background_trio_service(service: ServiceAPI) -> AsyncIterator[ManagerAPI]:
"""
Run a service in the background.
The service is running within the context
block and will be properly cleaned up upon exiting the context block.
"""
async with trio.open_nursery() as nursery:
manager = TrioManager(service)
nursery.start_soon(manager.run)
await manager.wait_started()
try:
yield manager
finally:
await manager.stop()