From d9b92635c1a7d6f2800c3db44e7c92a230f18bb1 Mon Sep 17 00:00:00 2001 From: pacrob <5199899+pacrob@users.noreply.github.com> Date: Sun, 19 May 2024 14:48:03 -0600 Subject: [PATCH] drop async-service dep and copy relevant code into a local async_service tool, updated for modern handling of ExceptionGroup --- .pre-commit-config.yaml | 1 + docs/libp2p.tools.async_service.rst | 61 ++ docs/libp2p.tools.rst | 1 + libp2p/host/basic_host.py | 6 +- libp2p/network/network_interface.py | 6 +- libp2p/network/swarm.py | 6 +- libp2p/protocol_muxer/multiselect_client.py | 4 +- libp2p/pubsub/abc.py | 7 +- libp2p/pubsub/gossipsub.py | 6 +- libp2p/pubsub/pubsub.py | 6 +- libp2p/tools/async_service/__init__.py | 15 + libp2p/tools/async_service/_utils.py | 41 ++ libp2p/tools/async_service/abc.py | 257 +++++++ libp2p/tools/async_service/base.py | 378 ++++++++++ libp2p/tools/async_service/exceptions.py | 26 + libp2p/tools/async_service/stats.py | 18 + libp2p/tools/async_service/trio_service.py | 446 ++++++++++++ libp2p/tools/async_service/typing.py | 16 + libp2p/tools/factories.py | 6 +- libp2p/tools/interop/process.py | 3 +- libp2p/tools/pubsub/dummy_account_node.py | 8 +- newsfragments/467.breaking.rst | 1 + setup.py | 3 +- tests/core/network/test_notify.py | 6 +- .../protocol_muxer/test_protocol_muxer.py | 20 +- .../async_service/test_trio_based_service.py | 668 ++++++++++++++++++ .../async_service/test_trio_external_api.py | 109 +++ .../async_service/test_trio_manager_stats.py | 86 +++ 28 files changed, 2176 insertions(+), 35 deletions(-) create mode 100644 docs/libp2p.tools.async_service.rst create mode 100644 libp2p/tools/async_service/__init__.py create mode 100644 libp2p/tools/async_service/_utils.py create mode 100644 libp2p/tools/async_service/abc.py create mode 100644 libp2p/tools/async_service/base.py create mode 100644 libp2p/tools/async_service/exceptions.py create mode 100644 libp2p/tools/async_service/stats.py create mode 100644 
libp2p/tools/async_service/trio_service.py create mode 100644 libp2p/tools/async_service/typing.py create mode 100644 newsfragments/467.breaking.rst create mode 100644 tests/core/tools/async_service/test_trio_based_service.py create mode 100644 tests/core/tools/async_service/test_trio_external_api.py create mode 100644 tests/core/tools/async_service/test_trio_manager_stats.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49759c9d..f710f242 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,6 +49,7 @@ repos: - id: mypy additional_dependencies: - mypy-protobuf + - trio-typing exclude: 'tests/' - repo: local hooks: diff --git a/docs/libp2p.tools.async_service.rst b/docs/libp2p.tools.async_service.rst new file mode 100644 index 00000000..d57f186a --- /dev/null +++ b/docs/libp2p.tools.async_service.rst @@ -0,0 +1,61 @@ +libp2p.tools.async\_service package +=================================== + +Submodules +---------- + +libp2p.tools.async\_service.abc module +-------------------------------------- + +.. automodule:: libp2p.tools.async_service.abc + :members: + :undoc-members: + :show-inheritance: + +libp2p.tools.async\_service.base module +--------------------------------------- + +.. automodule:: libp2p.tools.async_service.base + :members: + :undoc-members: + :show-inheritance: + +libp2p.tools.async\_service.exceptions module +--------------------------------------------- + +.. automodule:: libp2p.tools.async_service.exceptions + :members: + :undoc-members: + :show-inheritance: + +libp2p.tools.async\_service.stats module +---------------------------------------- + +.. automodule:: libp2p.tools.async_service.stats + :members: + :undoc-members: + :show-inheritance: + +libp2p.tools.async\_service.trio\_service module +------------------------------------------------ + +.. 
automodule:: libp2p.tools.async_service.trio_service + :members: + :undoc-members: + :show-inheritance: + +libp2p.tools.async\_service.typing module +----------------------------------------- + +.. automodule:: libp2p.tools.async_service.typing + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.tools.async_service + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.tools.rst b/docs/libp2p.tools.rst index 7230a537..b412db74 100644 --- a/docs/libp2p.tools.rst +++ b/docs/libp2p.tools.rst @@ -7,6 +7,7 @@ Subpackages .. toctree:: :maxdepth: 4 + libp2p.tools.async_service libp2p.tools.pubsub Submodules diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 18054290..f6157748 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -9,9 +9,6 @@ from typing import ( Sequence, ) -from async_service import ( - background_trio_service, -) import multiaddr from libp2p.crypto.keys import ( @@ -52,6 +49,9 @@ from libp2p.protocol_muxer.multiselect_client import ( from libp2p.protocol_muxer.multiselect_communicator import ( MultiselectCommunicator, ) +from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.typing import ( StreamHandlerFn, TProtocol, diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index db640465..676ac05f 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -8,9 +8,6 @@ from typing import ( Sequence, ) -from async_service import ( - ServiceAPI, -) from multiaddr import ( Multiaddr, ) @@ -24,6 +21,9 @@ from libp2p.peer.id import ( from libp2p.peer.peerstore_interface import ( IPeerStore, ) +from libp2p.tools.async_service import ( + ServiceAPI, +) from libp2p.transport.listener_interface import ( IListener, ) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 614bc3c2..0be40007 100644 --- a/libp2p/network/swarm.py +++ 
b/libp2p/network/swarm.py @@ -5,9 +5,6 @@ from typing import ( Optional, ) -from async_service import ( - Service, -) from multiaddr import ( Multiaddr, ) @@ -31,6 +28,9 @@ from libp2p.peer.peerstore_interface import ( from libp2p.stream_muxer.abc import ( IMuxedConn, ) +from libp2p.tools.async_service import ( + Service, +) from libp2p.transport.exceptions import ( MuxerUpgradeFailure, OpenConnectionError, diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index b5a7263f..06be47c3 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -32,7 +32,7 @@ class MultiselectClient(IMultiselectClient): Ensure that the client and multiselect are both using the same multiselect protocol. - :param stream: stream to communicate with multiselect over + :param communicator: communicator to use to communicate with counterparty :raise MultiselectClientError: raised when handshake failed """ try: @@ -57,7 +57,7 @@ class MultiselectClient(IMultiselectClient): protocol that multiselect agrees on (i.e. 
that multiselect selects) :param protocol: protocol to select - :param stream: stream to communicate with multiselect over + :param communicator: communicator to use to communicate with counterparty :return: selected protocol :raise MultiselectClientError: raised when protocol negotiation failed """ diff --git a/libp2p/pubsub/abc.py b/libp2p/pubsub/abc.py index aec6f8a4..263f8c42 100644 --- a/libp2p/pubsub/abc.py +++ b/libp2p/pubsub/abc.py @@ -11,13 +11,12 @@ from typing import ( Tuple, ) -from async_service import ( - ServiceAPI, -) - from libp2p.peer.id import ( ID, ) +from libp2p.tools.async_service import ( + ServiceAPI, +) from libp2p.typing import ( TProtocol, ) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 77cbd6f9..cabed25b 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -17,9 +17,6 @@ from typing import ( Tuple, ) -from async_service import ( - Service, -) import trio from libp2p.network.stream.exceptions import ( @@ -31,6 +28,9 @@ from libp2p.peer.id import ( from libp2p.pubsub import ( floodsub, ) +from libp2p.tools.async_service import ( + Service, +) from libp2p.typing import ( TProtocol, ) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index dd7bd52f..54485eb1 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -15,9 +15,6 @@ from typing import ( cast, ) -from async_service import ( - Service, -) import base58 from lru import ( LRU, @@ -51,6 +48,9 @@ from libp2p.network.stream.net_stream_interface import ( from libp2p.peer.id import ( ID, ) +from libp2p.tools.async_service import ( + Service, +) from libp2p.typing import ( TProtocol, ) diff --git a/libp2p/tools/async_service/__init__.py b/libp2p/tools/async_service/__init__.py new file mode 100644 index 00000000..5c42e135 --- /dev/null +++ b/libp2p/tools/async_service/__init__.py @@ -0,0 +1,15 @@ +from .abc import ( + ServiceAPI, +) +from .base import ( + Service, + as_service, +) +from .exceptions import ( + 
DaemonTaskExit, + LifecycleError, +) +from .trio_service import ( + TrioManager, + background_trio_service, +) diff --git a/libp2p/tools/async_service/_utils.py b/libp2p/tools/async_service/_utils.py new file mode 100644 index 00000000..6754e827 --- /dev/null +++ b/libp2p/tools/async_service/_utils.py @@ -0,0 +1,41 @@ +# Copied from https://github.com/ethereum/async-service + +import os +from typing import ( + Any, +) + + +def get_task_name(value: Any, explicit_name: str = None) -> str: + # inline import to ensure `_utils` is always importable from the rest of + # the module. + from .abc import ( # noqa: F401 + ServiceAPI, + ) + + if explicit_name is not None: + # if an explicit name was provided, just return that. + return explicit_name + elif isinstance(value, ServiceAPI): + # `Service` instance naming rules: + # + # 1. __str__ **if** the class implements a custom __str__ method + # 2. __repr__ **if** the class implements a custom __repr__ method + # 3. The `Service` class name. + value_cls = type(value) + if value_cls.__str__ is not object.__str__: + return str(value) + if value_cls.__repr__ is not object.__repr__: + return repr(value) + else: + return value.__class__.__name__ + else: + try: + # Prefer the name of the function if it has one + return str(value.__name__) # mypy doesn't know __name__ is a `str` + except AttributeError: + return repr(value) + + +def is_verbose_logging_enabled() -> bool: + return bool(os.environ.get("ASYNC_SERVICE_VERBOSE_LOG", False)) diff --git a/libp2p/tools/async_service/abc.py b/libp2p/tools/async_service/abc.py new file mode 100644 index 00000000..5b0a734b --- /dev/null +++ b/libp2p/tools/async_service/abc.py @@ -0,0 +1,257 @@ +# Copied from https://github.com/ethereum/async-service + +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Any, + Hashable, + Optional, + Set, +) + +import trio_typing + +from .stats import ( + Stats, +) +from .typing import ( + AsyncFn, +) + + +class TaskAPI(Hashable): + name: str 
+ daemon: bool + parent: Optional["TaskWithChildrenAPI"] + + @abstractmethod + async def run(self) -> None: + ... + + @abstractmethod + async def cancel(self) -> None: + ... + + @property + @abstractmethod + def is_done(self) -> bool: + ... + + @abstractmethod + async def wait_done(self) -> None: + ... + + +class TaskWithChildrenAPI(TaskAPI): + children: Set[TaskAPI] + + @abstractmethod + def add_child(self, child: TaskAPI) -> None: + ... + + @abstractmethod + def discard_child(self, child: TaskAPI) -> None: + ... + + +class ServiceAPI(ABC): + _manager: "InternalManagerAPI" + + @abstractmethod + def get_manager(self) -> "ManagerAPI": + """ + External retrieval of the manager for this service. + + Will raise a :class:`~async_service.exceptions.LifecycleError` if the + service does not yet have a `manager` assigned to it. + """ + ... + + @abstractmethod + async def run(self) -> None: + """ + Primary entry point for all service logic. + + .. note:: This method should **not** be directly invoked by user code. + + Services may be run using the following approaches. + + .. code-block: python + + # 1. run the service in the background using a context manager + async with run_service(service) as manager: + # service runs inside context block + ... + # service cancels and stops when context exits + # service will have fully stopped + + # 2. run the service blocking until completion + await Manager.run_service(service) + + # 3. create manager and then run service blocking until completion + manager = Manager(service) + await manager.run() + """ + ... + + +class ManagerAPI(ABC): + @property + @abstractmethod + def is_started(self) -> bool: + """ + Return boolean indicating if the underlying service has been started. + """ + ... + + @property + @abstractmethod + def is_running(self) -> bool: + """ + Return boolean indicating if the underlying service is actively + running. + + A service is considered running if it has been started and + has not yet been stopped. + """ + ... 
+ + @property + @abstractmethod + def is_cancelled(self) -> bool: + """ + Return boolean indicating if the underlying service has been cancelled. + + This can occur externally via the `cancel()` method or internally due + to a task crash or a crash of the actual :meth:`ServiceAPI.run` method. + """ + ... + + @property + @abstractmethod + def is_finished(self) -> bool: + """ + Return boolean indicating if the underlying service is stopped. + + A stopped service will have completed all of the background tasks. + """ + ... + + @property + @abstractmethod + def did_error(self) -> bool: + """ + Return boolean indicating if the underlying service threw an exception. + """ + ... + + @abstractmethod + def cancel(self) -> None: + """ + Trigger cancellation of the service. + """ + ... + + @abstractmethod + async def stop(self) -> None: + """ + Trigger cancellation of the service and wait for it to finish. + """ + ... + + @abstractmethod + async def wait_started(self) -> None: + """ + Wait until the service is started. + """ + ... + + @abstractmethod + async def wait_finished(self) -> None: + """ + Wait until the service is stopped. + """ + ... + + @classmethod + @abstractmethod + async def run_service(cls, service: ServiceAPI) -> None: + """ + Run a service + """ + ... + + @abstractmethod + async def run(self) -> None: + """ + Run a service + """ + ... + + @property + @abstractmethod + def stats(self) -> Stats: + """ + Return a stats object with details about the service. + """ + ... + + +class InternalManagerAPI(ManagerAPI): + """ + Defines the API that the `Service.manager` property exposes. + + The InternalManagerAPI / ManagerAPI distinction is in place to ensure that + external callers to a service do not try to use the task scheduling + functionality as it is only designed to be used internally. 
+ """ + + @trio_typing.takes_callable_and_args + @abstractmethod + def run_task( + self, async_fn: AsyncFn, *args: Any, daemon: bool = False, name: str = None + ) -> None: + """ + Run a task in the background. If the function throws an exception it + will trigger the service to be cancelled and be propagated. + + If `daemon == True` then the task is expected to run indefinitely + and will trigger cancellation if the task finishes. + """ + ... + + @trio_typing.takes_callable_and_args + @abstractmethod + def run_daemon_task(self, async_fn: AsyncFn, *args: Any, name: str = None) -> None: + """ + Run a daemon task in the background. + + Equivalent to `run_task(..., daemon=True)`. + """ + ... + + @abstractmethod + def run_child_service( + self, service: ServiceAPI, daemon: bool = False, name: str = None + ) -> "ManagerAPI": + """ + Run a service in the background. If the function throws an exception it + will trigger the parent service to be cancelled and be propagated. + + If `daemon == True` then the service is expected to run indefinitely + and will trigger cancellation if the service finishes. + """ + ... + + @abstractmethod + def run_daemon_child_service( + self, service: ServiceAPI, name: str = None + ) -> "ManagerAPI": + """ + Run a daemon service in the background. + + Equivalent to `run_child_service(..., daemon=True)`. + """ + ... 
diff --git a/libp2p/tools/async_service/base.py b/libp2p/tools/async_service/base.py new file mode 100644 index 00000000..9c3aee2d --- /dev/null +++ b/libp2p/tools/async_service/base.py @@ -0,0 +1,378 @@ +# Copied from https://github.com/ethereum/async-service + +from abc import ( + abstractmethod, +) +import asyncio +from collections import ( + Counter, +) +import logging +import sys +from typing import ( + Any, + Awaitable, + Callable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + TypeVar, + cast, +) +import uuid + +from ._utils import ( + is_verbose_logging_enabled, +) +from .abc import ( + InternalManagerAPI, + ManagerAPI, + ServiceAPI, + TaskAPI, + TaskWithChildrenAPI, +) +from .exceptions import ( + DaemonTaskExit, + LifecycleError, + TooManyChildrenException, +) +from .stats import ( + Stats, + TaskStats, +) +from .typing import ( + EXC_INFO, + AsyncFn, +) + +MAX_CHILDREN_TASKS = 1000 + + +class Service(ServiceAPI): + def __str__(self) -> str: + return self.__class__.__name__ + + @property + def manager(self) -> "InternalManagerAPI": + """ + Expose the manager as a property here instead of + :class:`async_service.abc.ServiceAPI` to ensure that anyone using + proper type hints will not have access to this property since it isn't + part of that API, while still allowing all subclasses of the + :class:`async_service.base.Service` to access this property directly. + """ + return self._manager + + def get_manager(self) -> ManagerAPI: + try: + return self._manager + except AttributeError: + raise LifecycleError( + "Service does not have a manager assigned to it. Are you sure " + "it is running?" 
+ ) + + +LogicFnType = Callable[..., Awaitable[Any]] + + +def as_service(service_fn: LogicFnType) -> Type[ServiceAPI]: + """ + Create a service out of a simple function + """ + + class _Service(Service): + def __init__(self, *args: Any, **kwargs: Any): + self._args = args + self._kwargs = kwargs + + async def run(self) -> None: + await service_fn(self.manager, *self._args, **self._kwargs) + + _Service.__name__ = service_fn.__name__ + _Service.__doc__ = service_fn.__doc__ + return _Service + + +class BaseTask(TaskAPI): + def __init__( + self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI] + ) -> None: + # meta + self.name = name + self.daemon = daemon + + # parent task + self.parent = parent + + # For hashable interface. + self._id = uuid.uuid4() + + def __hash__(self) -> int: + return hash(self._id) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, TaskAPI): + return hash(self) == hash(other) + else: + return False + + def __str__(self) -> str: + return f"{self.name}[daemon={self.daemon}]" + + +class BaseTaskWithChildren(BaseTask, TaskWithChildrenAPI): + def __init__( + self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI] + ) -> None: + super().__init__(name, daemon, parent) + self.children = set() + + def add_child(self, child: TaskAPI) -> None: + self.children.add(child) + + def discard_child(self, child: TaskAPI) -> None: + self.children.discard(child) + + +T = TypeVar("T", bound="BaseFunctionTask") + + +class BaseFunctionTask(BaseTaskWithChildren): + @classmethod + def iterate_tasks(cls: Type[T], *tasks: TaskAPI) -> Iterable[T]: + for task in tasks: + if isinstance(task, cls): + yield task + else: + continue + + yield from cls.iterate_tasks( + *( + child_task + for child_task in task.children + if isinstance(child_task, cls) + ) + ) + + def __init__( + self, + name: str, + daemon: bool, + parent: Optional[TaskWithChildrenAPI], + async_fn: AsyncFn, + async_fn_args: Sequence[Any], + ) -> None: + 
super().__init__(name, daemon, parent) + + self._async_fn = async_fn + self._async_fn_args = async_fn_args + + +class BaseChildServiceTask(BaseTask): + _child_service: ServiceAPI + child_manager: ManagerAPI + + async def run(self) -> None: + if self.child_manager.is_started: + raise LifecycleError( + f"Child service {self._child_service} has already been started" + ) + + try: + await self.child_manager.run() + + if self.daemon: + raise DaemonTaskExit(f"Daemon task {self} exited") + finally: + if self.parent is not None: + self.parent.discard_child(self) + + @property + def is_done(self) -> bool: + return self.child_manager.is_finished + + async def wait_done(self) -> None: + if self.child_manager.is_started: + await self.child_manager.wait_finished() + + +class BaseManager(InternalManagerAPI): + logger = logging.getLogger("async_service.Manager") + _verbose = is_verbose_logging_enabled() + + _service: ServiceAPI + + _errors: List[EXC_INFO] + + def __init__(self, service: ServiceAPI) -> None: + if hasattr(service, "_manager"): + raise LifecycleError("Service already has a manager.") + else: + service._manager = self + + self._service = service + + # errors + self._errors = [] + + # tasks + self._root_tasks: Set[TaskAPI] = set() + + # stats + self._total_task_count = 0 + self._done_task_count = 0 + + def __str__(self) -> str: + status_flags = "".join( + ( + "S" if self.is_started else "s", + "R" if self.is_running else "r", + "C" if self.is_cancelled else "c", + "F" if self.is_finished else "f", + "E" if self.did_error else "e", + ) + ) + return f"" + + # + # Event API mirror + # + @property + def is_running(self) -> bool: + return self.is_started and not self.is_finished + + @property + def did_error(self) -> bool: + return len(self._errors) > 0 + + # + # Control API + # + async def stop(self) -> None: + self.cancel() + await self.wait_finished() + + # + # Wait API + # + def run_daemon_task( + self, async_fn: Callable[..., Awaitable[Any]], *args: Any, name: str = 
None + ) -> None: + self.run_task(async_fn, *args, daemon=True, name=name) + + def run_daemon_child_service( + self, service: ServiceAPI, name: str = None + ) -> ManagerAPI: + return self.run_child_service(service, daemon=True, name=name) + + @property + def stats(self) -> Stats: + # The `max` call here ensures that if this is called prior to the + # `Service.run` method starting we don't return `-1` + total_count = max(0, self._total_task_count) + + # Since we track `Service.run` as a task, the `min` call here ensures + # that when the service is fully done that we don't represent the + # `Service.run` method in this count. + finished_count = min(total_count, self._done_task_count) + return Stats( + tasks=TaskStats(total_count=total_count, finished_count=finished_count) + ) + + # + # Task Management + # + @abstractmethod + def _schedule_task(self, task: TaskAPI) -> None: + ... + + def _common_run_task(self, task: TaskAPI) -> None: + if not self.is_running: + raise LifecycleError( + "Tasks may not be scheduled if the service is not running" + ) + + if self.is_running and self.is_cancelled: + self.logger.debug( + "%s: service is being cancelled. Not running task %s", self, task + ) + return + + self._add_child_task(task.parent, task) + self._total_task_count += 1 + + self._schedule_task(task) + + def _add_child_task( + self, parent: Optional[TaskWithChildrenAPI], task: TaskAPI + ) -> None: + if parent is None: + all_children = self._root_tasks + else: + all_children = parent.children + + if len(all_children) > MAX_CHILDREN_TASKS: + task_counter = Counter(map(str, all_children)) + raise TooManyChildrenException( + f"Tried to add more than {MAX_CHILDREN_TASKS} child tasks." 
+ f" Most common tasks: {task_counter.most_common(10)}" + ) + + if parent is None: + if self._verbose: + self.logger.debug("%s: running root task %s", self, task) + self._root_tasks.add(task) + else: + if self._verbose: + self.logger.debug("%s: %s running child task %s", self, parent, task) + parent.add_child(task) + + async def _run_and_manage_task(self, task: TaskAPI) -> None: + if self._verbose: + self.logger.debug("%s: task %s running", self, task) + + try: + try: + await task.run() + except DaemonTaskExit: + if self.is_cancelled: + pass + else: + raise + finally: + if isinstance(task, TaskWithChildrenAPI): + new_parent = task.parent + for child in task.children: + child.parent = new_parent + self._add_child_task(new_parent, child) + self.logger.debug( + "%s left a child task (%s) behind, reassigning it to %s", + task, + child, + new_parent or "root", + ) + except asyncio.CancelledError: + self.logger.debug("%s: task %s raised CancelledError.", self, task) + raise + except Exception as err: + self.logger.error( + "%s: task %s exited with error: %s", + self, + task, + err, + # Only show stacktrace if this is **not** a DaemonTaskExit error + exc_info=not isinstance(err, DaemonTaskExit), + ) + self._errors.append(cast(EXC_INFO, sys.exc_info())) + self.cancel() + else: + if task.parent is None: + self._root_tasks.remove(task) + if self._verbose: + self.logger.debug("%s: task %s exited cleanly.", self, task) + finally: + self._done_task_count += 1 diff --git a/libp2p/tools/async_service/exceptions.py b/libp2p/tools/async_service/exceptions.py new file mode 100644 index 00000000..ccb13298 --- /dev/null +++ b/libp2p/tools/async_service/exceptions.py @@ -0,0 +1,26 @@ +# Copied from https://github.com/ethereum/async-service + + +class ServiceException(Exception): + """ + Base class for Service exceptions + """ + + +class LifecycleError(ServiceException): + """ + Raised when an action would violate the service lifecycle rules. 
+ """ + + +class DaemonTaskExit(ServiceException): + """ + Raised when a task or child service that was run as a daemon exits. + """ + + +class TooManyChildrenException(ServiceException): + """ + Raised when a service adds too many children. It is a sign of task leakage + that needs to be prevented. + """ diff --git a/libp2p/tools/async_service/stats.py b/libp2p/tools/async_service/stats.py new file mode 100644 index 00000000..4f8b8fab --- /dev/null +++ b/libp2p/tools/async_service/stats.py @@ -0,0 +1,18 @@ +# Copied from https://github.com/ethereum/async-service + +from typing import ( + NamedTuple, +) + + +class TaskStats(NamedTuple): + total_count: int + finished_count: int + + @property + def pending_count(self) -> int: + return self.total_count - self.finished_count + + +class Stats(NamedTuple): + tasks: TaskStats diff --git a/libp2p/tools/async_service/trio_service.py b/libp2p/tools/async_service/trio_service.py new file mode 100644 index 00000000..db208947 --- /dev/null +++ b/libp2p/tools/async_service/trio_service.py @@ -0,0 +1,446 @@ +# Originally copied from https://github.com/ethereum/async-service +from __future__ import ( + annotations, +) + +from contextlib import ( + asynccontextmanager, +) +import functools +import sys +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Coroutine, + Optional, + Sequence, + Tuple, + TypeVar, + cast, +) + +if sys.version_info >= (3, 11): + from builtins import ( + ExceptionGroup, + ) +else: + from exceptiongroup import ExceptionGroup + +import trio +import trio_typing + +from ._utils import ( + get_task_name, +) +from .abc import ( + ManagerAPI, + ServiceAPI, + TaskAPI, + TaskWithChildrenAPI, +) +from .base import ( + BaseChildServiceTask, + BaseFunctionTask, + BaseManager, +) +from .exceptions import ( + DaemonTaskExit, + LifecycleError, +) +from .typing import ( + EXC_INFO, + AsyncFn, +) + + +class FunctionTask(BaseFunctionTask): + _trio_task: trio.lowlevel.Task | None = None + + def __init__( + 
self, + name: str, + daemon: bool, + parent: TaskWithChildrenAPI | None, + async_fn: AsyncFn, + async_fn_args: Sequence[Any], + ) -> None: + super().__init__(name, daemon, parent, async_fn, async_fn_args) + + # We use an event to manually track when the child task is "done". + # This is because trio has no API for awaiting completion of a task. + self._done = trio.Event() + + # Each task gets its own `CancelScope` which is how we can manually + # control cancellation order of the task DAG + self._cancel_scope = trio.CancelScope() + + # + # Trio specific API + # + @property + def has_trio_task(self) -> bool: + return self._trio_task is not None + + @property + def trio_task(self) -> trio.lowlevel.Task: + if self._trio_task is None: + raise LifecycleError("Trio task not set yet") + return self._trio_task + + @trio_task.setter + def trio_task(self, value: trio.lowlevel.Task) -> None: + if self._trio_task is not None: + raise LifecycleError(f"Task already set: {self._trio_task}") + self._trio_task = value + + # + # Core Task API + # + async def run(self) -> None: + self.trio_task = trio.lowlevel.current_task() + + try: + with self._cancel_scope: + await self._async_fn(*self._async_fn_args) + if self.daemon: + raise DaemonTaskExit(f"Daemon task {self} exited") + + while self.children: + await tuple(self.children)[0].wait_done() + finally: + self._done.set() + if self.parent is not None: + self.parent.discard_child(self) + + async def cancel(self) -> None: + for task in tuple(self.children): + await task.cancel() + self._cancel_scope.cancel() + await self.wait_done() + + @property + def is_done(self) -> bool: + return self._done.is_set() + + async def wait_done(self) -> None: + await self._done.wait() + + +class ChildServiceTask(BaseChildServiceTask): + def __init__( + self, + name: str, + daemon: bool, + parent: TaskWithChildrenAPI | None, + child_service: ServiceAPI, + ) -> None: + super().__init__(name, daemon, parent) + + self._child_service = child_service + 
+ self.child_manager = TrioManager(child_service) + + async def cancel(self) -> None: + if self.child_manager.is_started: + await self.child_manager.stop() + + +class TrioManager(BaseManager): + # A nursery for sub tasks and services. This nursery is cancelled if the + # service is cancelled but allowed to exit normally if the service exits. + _task_nursery: trio_typing.Nursery + + def __init__(self, service: ServiceAPI) -> None: + super().__init__(service) + + # events + self._started = trio.Event() + self._cancelled = trio.Event() + self._finished = trio.Event() + + # locks + self._run_lock = trio.Lock() + + # + # System Tasks + # + async def _handle_cancelled(self) -> None: + self.logger.debug("%s: _handle_cancelled waiting for cancellation", self) + await self._cancelled.wait() + self.logger.debug("%s: _handle_cancelled triggering task cancellation", self) + + # The `_root_tasks` changes size as each task completes itself + # and removes itself from the set. For this reason we iterate over a + # copy of the set. + for task in tuple(self._root_tasks): + await task.cancel() + + # This final cancellation of the task nursery's cancel scope ensures + # that nothing is left behind and that the service will reliably exit. + self._task_nursery.cancel_scope.cancel() + + @classmethod + async def run_service(cls, service: ServiceAPI) -> None: + manager = cls(service) + await manager.run() + + async def run(self) -> None: + if self._run_lock.locked(): + raise LifecycleError( + "Cannot run a service with the run lock already engaged. " + "Already started?" 
+ ) + elif self.is_started: + raise LifecycleError("Cannot run a service which is already started.") + + try: + async with self._run_lock: + async with trio.open_nursery() as system_nursery: + system_nursery.start_soon(self._handle_cancelled) + + try: + async with trio.open_nursery() as task_nursery: + self._task_nursery = task_nursery + + self._started.set() + + self.run_task(self._service.run, name="run") + + # This is a hack to get the task stats correct. We don't want + # to count the `Service.run` method as a task. This is still + # imperfect as it will still count as a completed task when + # it finishes. + self._total_task_count = 0 + + # ***BLOCKING HERE*** + # The code flow will block here until the background tasks + # have completed or cancellation occurs. + except Exception: + # Exceptions from any tasks spawned by our service will be + # caught by trio and raised here, so we store them to report + # together with any others we have already captured. + self._errors.append(cast(EXC_INFO, sys.exc_info())) + finally: + system_nursery.cancel_scope.cancel() + + finally: + # We need this inside a finally because a trio.Cancelled exception may be + # raised here and it wouldn't be swallowed by the 'except Exception' above. + self._finished.set() + self.logger.debug("%s: finished", self) + + # This is outside of the finally block above because we don't want to suppress + # trio.Cancelled or ExceptionGroup exceptions coming directly from trio. 
+ if self.did_error: + raise ExceptionGroup( + "Encountered multiple Exceptions: ", + tuple( + exc_value.with_traceback(exc_tb) + for _, exc_value, exc_tb in self._errors + if isinstance(exc_value, Exception) + ), + ) + + # + # Event API mirror + # + @property + def is_started(self) -> bool: + return self._started.is_set() + + @property + def is_cancelled(self) -> bool: + return self._cancelled.is_set() + + @property + def is_finished(self) -> bool: + return self._finished.is_set() + + # + # Control API + # + def cancel(self) -> None: + if not self.is_started: + raise LifecycleError("Cannot cancel as service which was never started.") + elif not self.is_running: + return + else: + self._cancelled.set() + + # + # Wait API + # + async def wait_started(self) -> None: + await self._started.wait() + + async def wait_finished(self) -> None: + await self._finished.wait() + + def _find_parent_task( + self, trio_task: trio.lowlevel.Task + ) -> TaskWithChildrenAPI | None: + """ + Find the :class:`async_service.trio.FunctionTask` instance that corresponds to + the given :class:`trio.lowlevel.Task` instance. + """ + for task in FunctionTask.iterate_tasks(*self._root_tasks): + # Any task that has not had its `trio_task` set can be safely + # skipped as those are still in the process of starting up which + # means that they cannot be the parent task since they will not + # have had a chance to schedule child tasks. + if not task.has_trio_task: + continue + + if trio_task is task.trio_task: + return task + + else: + # In the case that no tasks match we assume this is a new `root` + # task and return `None` as the parent. 
+ return None + + def _schedule_task(self, task: TaskAPI) -> None: + self._task_nursery.start_soon(self._run_and_manage_task, task, name=str(task)) + + def run_task( + self, + async_fn: Callable[..., Awaitable[Any]], + *args: Any, + daemon: bool = False, + name: str = None, + ) -> None: + task = FunctionTask( + name=get_task_name(async_fn, name), + daemon=daemon, + parent=self._find_parent_task(trio.lowlevel.current_task()), + async_fn=async_fn, + async_fn_args=args, + ) + + self._common_run_task(task) + + def run_child_service( + self, service: ServiceAPI, daemon: bool = False, name: str = None + ) -> ManagerAPI: + task = ChildServiceTask( + name=get_task_name(service, name), + daemon=daemon, + parent=self._find_parent_task(trio.lowlevel.current_task()), + child_service=service, + ) + + self._common_run_task(task) + return task.child_manager + + +TFunc = TypeVar("TFunc", bound=Callable[..., Coroutine[Any, Any, Any]]) + + +_ChannelPayload = Tuple[Optional[Any], Optional[BaseException]] + + +async def _wait_finished( + service: ServiceAPI, + api_func: Callable[..., Any], + channel: trio.abc.SendChannel[_ChannelPayload], +) -> None: + manager = service.get_manager() + + if manager.is_finished: + await channel.send( + ( + None, + LifecycleError( + f"Cannot access external API {api_func}. " + f"Service {service} is not running: " + ), + ) + ) + return + + await manager.wait_finished() + await channel.send( + ( + None, + LifecycleError( + f"Cannot access external API {api_func}. 
" + f"Service {service} is not running: " + ), + ) + ) + + +async def _wait_api_fn( + self: ServiceAPI, + api_fn: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + channel: trio.abc.SendChannel[_ChannelPayload], +) -> None: + try: + result = await api_fn(self, *args, **kwargs) + except Exception: + _, exc_value, exc_tb = sys.exc_info() + if exc_value is None or exc_tb is None: + raise Exception( + "This should be unreachable but acts as a type guard for mypy" + ) + await channel.send((None, exc_value.with_traceback(exc_tb))) + else: + await channel.send((result, None)) + + +def external_api(func: TFunc) -> TFunc: + @functools.wraps(func) + async def inner(self: ServiceAPI, *args: Any, **kwargs: Any) -> Any: + if not hasattr(self, "manager"): + raise LifecycleError( + f"Cannot access external API {func}. Service {self} has not been run." + ) + + manager = self.get_manager() + + if not manager.is_running: + raise LifecycleError( + f"Cannot access external API {func}. Service {self} is not running: " + ) + + channels: tuple[ + trio.abc.SendChannel[_ChannelPayload], + trio.abc.ReceiveChannel[_ChannelPayload], + ] = trio.open_memory_channel(0) + send_channel, receive_channel = channels + + async with trio.open_nursery() as nursery: + # mypy's type hints for start_soon break with this invocation. + nursery.start_soon( + _wait_api_fn, self, func, args, kwargs, send_channel # type: ignore + ) + nursery.start_soon(_wait_finished, self, func, send_channel) + result, err = await receive_channel.receive() + nursery.cancel_scope.cancel() + if err is None: + return result + else: + raise err + + return cast(TFunc, inner) + + +@asynccontextmanager +async def background_trio_service(service: ServiceAPI) -> AsyncIterator[ManagerAPI]: + """ + Run a service in the background. + + The service is running within the context + block and will be properly cleaned up upon exiting the context block. 
+ """ + async with trio.open_nursery() as nursery: + manager = TrioManager(service) + nursery.start_soon(manager.run) + await manager.wait_started() + try: + yield manager + finally: + await manager.stop() diff --git a/libp2p/tools/async_service/typing.py b/libp2p/tools/async_service/typing.py new file mode 100644 index 00000000..f55398ff --- /dev/null +++ b/libp2p/tools/async_service/typing.py @@ -0,0 +1,16 @@ +# Copied from https://github.com/ethereum/async-service + +from types import ( + TracebackType, +) +from typing import ( + Any, + Awaitable, + Callable, + Tuple, + Type, +) + +EXC_INFO = Tuple[Type[BaseException], BaseException, TracebackType] + +AsyncFn = Callable[..., Awaitable[Any]] diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 42f81646..326b1fca 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -15,9 +15,6 @@ from typing import ( from async_exit_stack import ( AsyncExitStack, ) -from async_service import ( - background_trio_service, -) import factory from multiaddr import ( Multiaddr, @@ -111,6 +108,9 @@ from libp2p.stream_muxer.mplex.mplex import ( from libp2p.stream_muxer.mplex.mplex_stream import ( MplexStream, ) +from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.tools.constants import ( GOSSIPSUB_PARAMS, ) diff --git a/libp2p/tools/interop/process.py b/libp2p/tools/interop/process.py index f6e56130..e49ba1d8 100644 --- a/libp2p/tools/interop/process.py +++ b/libp2p/tools/interop/process.py @@ -56,7 +56,8 @@ class BaseInteractiveProcess(AbstractInterativeProcess): async def start(self) -> None: if self.proc is not None: return - self.proc = await trio.open_process( + # mypy says that `open_process` is not an attribute of trio, suggests run_process instead. 
# noqa: E501 + self.proc = await trio.open_process( # type: ignore[attr-defined] [self.cmd] + self.args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier # noqa: E501 diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index ec12ade2..0a80a27f 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -10,10 +10,6 @@ from typing import ( from async_exit_stack import ( AsyncExitStack, ) -from async_service import ( - Service, - background_trio_service, -) from libp2p.host.host_interface import ( IHost, @@ -21,6 +17,10 @@ from libp2p.host.host_interface import ( from libp2p.pubsub.pubsub import ( Pubsub, ) +from libp2p.tools.async_service import ( + Service, + background_trio_service, +) from libp2p.tools.factories import ( PubsubFactory, ) diff --git a/newsfragments/467.breaking.rst b/newsfragments/467.breaking.rst new file mode 100644 index 00000000..17fc5099 --- /dev/null +++ b/newsfragments/467.breaking.rst @@ -0,0 +1 @@ +Drop dep for unmaintained ``async-service`` and copy relevant functions into a local tool of the same name diff --git a/setup.py b/setup.py index 62f33ab3..cf35f32e 100644 --- a/setup.py +++ b/setup.py @@ -62,9 +62,10 @@ install_requires = [ "coincurve>=10.0.0", "pynacl==1.3.0", "trio>=0.15.0", - "async-service>=0.1.0a6", "async-exit-stack==1.0.1", "noiseprotocol>=0.3.0", + "trio-typing>=0.0.4", + "exceptiongroup>=1.2.0; python_version < '3.11'", # added during debugging "anyio", "p2pclient", diff --git a/tests/core/network/test_notify.py b/tests/core/network/test_notify.py index b01a34a3..b9b7ed8e 100644 --- a/tests/core/network/test_notify.py +++ b/tests/core/network/test_notify.py @@ -10,15 +10,15 @@ features are implemented in swarm """ import enum -from async_service import ( - background_trio_service, -) import pytest import trio from libp2p.network.notifee_interface import ( INotifee, ) 
+from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.tools.constants import ( LISTEN_MADDR, ) diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index ce6be7ac..0e09f061 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -1,4 +1,7 @@ import pytest +from trio.testing import ( + RaisesGroup, +) from libp2p.host.exceptions import ( StreamFailure, @@ -58,7 +61,13 @@ async def test_single_protocol_succeeds(security_protocol): @pytest.mark.trio async def test_single_protocol_fails(security_protocol): - with pytest.raises(StreamFailure): + # using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups + # yet: https://github.com/pytest-dev/pytest/issues/11538 + # but switch to that once they do + + # the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False + # to unwrap down to the core Exception + with RaisesGroup(StreamFailure, strict=False): await perform_simple_test( "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol ) @@ -96,7 +105,14 @@ async def test_multiple_protocol_second_is_valid_succeeds(security_protocol): async def test_multiple_protocol_fails(security_protocol): protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] - with pytest.raises(StreamFailure): + + # using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups + # yet: https://github.com/pytest-dev/pytest/issues/11538 + # but switch to that once they do + + # the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False + # to unwrap down to the core Exception + with RaisesGroup(StreamFailure, strict=False): await perform_simple_test( "", protocols_for_client, protocols_for_listener, security_protocol ) diff --git 
a/tests/core/tools/async_service/test_trio_based_service.py b/tests/core/tools/async_service/test_trio_based_service.py new file mode 100644 index 00000000..2039d917 --- /dev/null +++ b/tests/core/tools/async_service/test_trio_based_service.py @@ -0,0 +1,668 @@ +import sys + +if sys.version_info >= (3, 11): + from builtins import ( + ExceptionGroup, + ) +else: + from exceptiongroup import ( + ExceptionGroup, + ) + +import pytest +import trio +from trio.testing import ( + Matcher, + RaisesGroup, +) + +from libp2p.tools.async_service import ( + DaemonTaskExit, + LifecycleError, + Service, + TrioManager, + as_service, + background_trio_service, +) + + +class WaitCancelledService(Service): + async def run(self) -> None: + await self.manager.wait_finished() + + +async def do_service_lifecycle_check( + manager, manager_run_fn, trigger_exit_condition_fn, should_be_cancelled +): + async with trio.open_nursery() as nursery: + assert manager.is_started is False + assert manager.is_running is False + assert manager.is_cancelled is False + assert manager.is_finished is False + + nursery.start_soon(manager_run_fn) + + with trio.fail_after(0.1): + await manager.wait_started() + + assert manager.is_started is True + assert manager.is_running is True + assert manager.is_cancelled is False + assert manager.is_finished is False + + # trigger the service to exit + trigger_exit_condition_fn() + + with trio.fail_after(0.1): + await manager.wait_finished() + + if should_be_cancelled: + assert manager.is_started is True + # We cannot determine whether the service should be running at this + # stage because a service is considered running until it is + # finished. Since it may be cancelled but still not finished we + # can't know. 
+ assert manager.is_cancelled is True + # We also cannot determine whether a service should be finished at this + # stage as it could have exited cleanly and is now finished or it + # might be doing some cleanup after which it will register as being + # finished. + assert manager.is_running is True or manager.is_finished is True + + assert manager.is_started is True + assert manager.is_running is False + assert manager.is_cancelled is should_be_cancelled + assert manager.is_finished is True + + +def test_service_manager_initial_state(): + service = WaitCancelledService() + manager = TrioManager(service) + + assert manager.is_started is False + assert manager.is_running is False + assert manager.is_cancelled is False + assert manager.is_finished is False + + +@pytest.mark.trio +async def test_trio_service_lifecycle_run_and_clean_exit(): + trigger_exit = trio.Event() + + @as_service + async def ServiceTest(manager): + await trigger_exit.wait() + + service = ServiceTest() + manager = TrioManager(service) + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=manager.run, + trigger_exit_condition_fn=trigger_exit.set, + should_be_cancelled=False, + ) + + +@pytest.mark.trio +async def test_trio_service_lifecycle_run_and_external_cancellation(): + @as_service + async def ServiceTest(manager): + await trio.sleep_forever() + + service = ServiceTest() + manager = TrioManager(service) + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=manager.run, + trigger_exit_condition_fn=manager.cancel, + should_be_cancelled=True, + ) + + +@pytest.mark.trio +async def test_trio_service_lifecycle_run_and_exception(): + trigger_error = trio.Event() + + @as_service + async def ServiceTest(manager): + await trigger_error.wait() + raise RuntimeError("Service throwing error") + + service = ServiceTest() + manager = TrioManager(service) + + async def do_service_run(): + with RaisesGroup( + Matcher(RuntimeError, match="Service throwing error"), 
strict=False + ): + await manager.run() + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=do_service_run, + trigger_exit_condition_fn=trigger_error.set, + should_be_cancelled=True, + ) + + +@pytest.mark.trio +async def test_trio_service_lifecycle_run_and_task_exception(): + trigger_error = trio.Event() + + @as_service + async def ServiceTest(manager): + async def task_fn(): + await trigger_error.wait() + raise RuntimeError("Service throwing error") + + manager.run_task(task_fn) + + service = ServiceTest() + manager = TrioManager(service) + + async def do_service_run(): + with RaisesGroup( + Matcher(RuntimeError, match="Service throwing error"), strict=False + ): + await manager.run() + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=do_service_run, + trigger_exit_condition_fn=trigger_error.set, + should_be_cancelled=True, + ) + + +@pytest.mark.trio +async def test_sub_service_cancelled_when_parent_stops(): + ready_cancel = trio.Event() + + # This test runs a service that runs a sub-service that sleeps forever. When the + # parent exits, the sub-service should be cancelled as well. 
+ @as_service + async def WaitForeverService(manager): + ready_cancel.set() + await manager.wait_finished() + + sub_manager = TrioManager(WaitForeverService()) + + @as_service + async def ServiceTest(manager): + async def run_sub(): + await sub_manager.run() + + manager.run_task(run_sub) + await manager.wait_finished() + + s = ServiceTest() + async with background_trio_service(s) as manager: + await ready_cancel.wait() + + assert not manager.is_running + assert manager.is_cancelled + assert manager.is_finished + + assert not sub_manager.is_running + assert not sub_manager.is_cancelled + assert sub_manager.is_finished + + +@pytest.mark.trio +async def test_trio_service_lifecycle_run_and_daemon_task_exit(): + trigger_error = trio.Event() + + @as_service + async def ServiceTest(manager): + async def daemon_task_fn(): + await trigger_error.wait() + + manager.run_daemon_task(daemon_task_fn) + await manager.wait_finished() + + service = ServiceTest() + manager = TrioManager(service) + + async def do_service_run(): + with RaisesGroup(Matcher(DaemonTaskExit, match="Daemon task"), strict=False): + await manager.run() + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=do_service_run, + trigger_exit_condition_fn=trigger_error.set, + should_be_cancelled=True, + ) + + +@pytest.mark.trio +async def test_exceptiongroup_in_run(): + # This test should cause TrioManager.run() to explicitly raise an ExceptionGroup + # containing two exceptions -- one raised inside its run() method and another + # raised by the daemon task exiting early. 
+ trigger_error = trio.Event() + + class ServiceTest(Service): + async def run(self): + ready = trio.Event() + self.manager.run_task(self.error_fn, ready) + await ready.wait() + trigger_error.set() + raise RuntimeError("Exception inside Service.run()") + + async def error_fn(self, ready): + ready.set() + await trigger_error.wait() + raise ValueError("Exception inside error_fn") + + with pytest.raises(ExceptionGroup) as exc_info: + await TrioManager.run_service(ServiceTest()) + + exc = exc_info.value + assert len(exc.exceptions) == 2 + assert any(isinstance(err, RuntimeError) for err in exc.exceptions) + assert any(isinstance(err, ValueError) for err in exc.exceptions) + + +@pytest.mark.trio +async def test_trio_service_background_service_context_manager(): + service = WaitCancelledService() + + async with background_trio_service(service) as manager: + # ensure the manager property is set. + assert hasattr(service, "manager") + assert service.get_manager() is manager + + assert manager.is_started is True + assert manager.is_running is True + assert manager.is_cancelled is False + assert manager.is_finished is False + + assert manager.is_started is True + assert manager.is_running is False + assert manager.is_cancelled is True + assert manager.is_finished is True + + +@pytest.mark.trio +async def test_trio_service_manager_stop(): + service = WaitCancelledService() + + async with background_trio_service(service) as manager: + assert manager.is_started is True + assert manager.is_running is True + assert manager.is_cancelled is False + assert manager.is_finished is False + + await manager.stop() + + assert manager.is_started is True + assert manager.is_running is False + assert manager.is_cancelled is True + assert manager.is_finished is True + + +@pytest.mark.trio +async def test_trio_service_manager_run_task(): + task_event = trio.Event() + + @as_service + async def RunTaskService(manager): + async def task_fn(): + task_event.set() + + manager.run_task(task_fn) + 
await manager.wait_finished() + + async with background_trio_service(RunTaskService()): + with trio.fail_after(0.1): + await task_event.wait() + + +@pytest.mark.trio +async def test_trio_service_manager_run_task_waits_for_task_completion(): + task_event = trio.Event() + + @as_service + async def RunTaskService(manager): + async def task_fn(): + await trio.sleep(0.01) + task_event.set() + + manager.run_task(task_fn) + # the task is set to run in the background but then the service exits. + # We want to be sure that the task is allowed to continue till + # completion unless explicitely cancelled. + + async with background_trio_service(RunTaskService()): + with trio.fail_after(0.1): + await task_event.wait() + + +@pytest.mark.trio +async def test_trio_service_manager_run_task_can_still_cancel_after_run_finishes(): + task_event = trio.Event() + service_finished = trio.Event() + + @as_service + async def RunTaskService(manager): + async def task_fn(): + # this will never complete + await task_event.wait() + + manager.run_task(task_fn) + # the task is set to run in the background but then the service exits. + # We want to be sure that the task is allowed to continue till + # completion unless explicitely cancelled. + service_finished.set() + + async with background_trio_service(RunTaskService()) as manager: + with trio.fail_after(0.01): + await service_finished.wait() + + # show that the service hangs waiting for the task to complete. 
+ with trio.move_on_after(0.01) as cancel_scope: + await manager.wait_finished() + assert cancel_scope.cancelled_caught is True + + # trigger cancellation and see that the service actually stops + manager.cancel() + with trio.fail_after(0.01): + await manager.wait_finished() + + +@pytest.mark.trio +async def test_trio_service_manager_run_task_reraises_exceptions(): + task_event = trio.Event() + + @as_service + async def RunTaskService(manager): + async def task_fn(): + await task_event.wait() + raise Exception("task exception in run_task") + + manager.run_task(task_fn) + with trio.fail_after(1): + await trio.sleep_forever() + + with RaisesGroup( + Matcher(Exception, match="task exception in run_task"), strict=False + ): + async with background_trio_service(RunTaskService()): + task_event.set() + with trio.fail_after(1): + await trio.sleep_forever() + + +@pytest.mark.trio +async def test_trio_service_manager_run_daemon_task_cancels_if_exits(): + task_event = trio.Event() + + @as_service + async def RunTaskService(manager): + async def daemon_task_fn(): + await task_event.wait() + + manager.run_daemon_task(daemon_task_fn, name="daemon_task_fn") + with trio.fail_after(1): + await trio.sleep_forever() + + with RaisesGroup( + Matcher( + DaemonTaskExit, match=r"Daemon task daemon_task_fn\[daemon=True\] exited" + ), + strict=False, + ): + async with background_trio_service(RunTaskService()): + task_event.set() + with trio.fail_after(1): + await trio.sleep_forever() + + +@pytest.mark.trio +async def test_trio_service_manager_propogates_and_records_exceptions(): + @as_service + async def ThrowErrorService(manager): + raise RuntimeError("this is the error") + + service = ThrowErrorService() + manager = TrioManager(service) + + assert manager.did_error is False + + with RaisesGroup(Matcher(RuntimeError, match="this is the error"), strict=False): + await manager.run() + + assert manager.did_error is True + + +@pytest.mark.trio +async def 
test_trio_service_lifecycle_run_and_clean_exit_with_child_service(): + trigger_exit = trio.Event() + + @as_service + async def ChildServiceTest(manager): + await trigger_exit.wait() + + @as_service + async def ServiceTest(manager): + child_manager = manager.run_child_service(ChildServiceTest()) + await child_manager.wait_started() + + service = ServiceTest() + manager = TrioManager(service) + + await do_service_lifecycle_check( + manager=manager, + manager_run_fn=manager.run, + trigger_exit_condition_fn=trigger_exit.set, + should_be_cancelled=False, + ) + + +@pytest.mark.trio +async def test_trio_service_with_daemon_child_service(): + ready = trio.Event() + + @as_service + async def ChildServiceTest(manager): + await manager.wait_finished() + + @as_service + async def ServiceTest(manager): + child_manager = manager.run_daemon_child_service(ChildServiceTest()) + await child_manager.wait_started() + ready.set() + await manager.wait_finished() + + service = ServiceTest() + async with background_trio_service(service): + await ready.wait() + + +@pytest.mark.trio +async def test_trio_service_with_daemon_child_task(): + ready = trio.Event() + started = trio.Event() + + async def _task(): + started.set() + await trio.sleep(100) + + @as_service + async def ServiceTest(manager): + manager.run_daemon_task(_task) + await started.wait() + ready.set() + await manager.wait_finished() + + service = ServiceTest() + async with background_trio_service(service): + await ready.wait() + + +@pytest.mark.trio +async def test_trio_service_with_async_generator(): + is_within_agen = trio.Event() + + async def do_agen(): + while True: + yield + + @as_service + async def ServiceTest(manager): + async for _ in do_agen(): # noqa: F841 + await trio.lowlevel.checkpoint() + is_within_agen.set() + + async with background_trio_service(ServiceTest()) as manager: + await is_within_agen.wait() + manager.cancel() + + +@pytest.mark.trio +async def 
test_trio_service_disallows_task_scheduling_when_not_running(): + class ServiceTest(Service): + async def run(self): + await self.manager.wait_finished() + + def do_schedule(self): + self.manager.run_task(trio.sleep, 1) + + service = ServiceTest() + + async with background_trio_service(service): + service.do_schedule() + + with pytest.raises(LifecycleError): + service.do_schedule() + + +@pytest.mark.trio +async def test_trio_service_disallows_task_scheduling_after_cancel(): + @as_service + async def ServiceTest(manager): + manager.cancel() + manager.run_task(trio.sleep, 1) + + await TrioManager.run_service(ServiceTest()) + + +@pytest.mark.trio +async def test_trio_service_cancellation_with_running_daemon_task(): + in_daemon = trio.Event() + + class ServiceTest(Service): + async def run(self): + self.manager.run_daemon_task(self._do_daemon) + await self.manager.wait_finished() + + async def _do_daemon(self): + in_daemon.set() + while self.manager.is_running: + await trio.lowlevel.checkpoint() + + async with background_trio_service(ServiceTest()) as manager: + await in_daemon.wait() + manager.cancel() + + +@pytest.mark.trio +async def test_trio_service_with_try_finally_cleanup(): + ready_cancel = trio.Event() + + class TryFinallyService(Service): + cleanup_up = False + + async def run(self) -> None: + try: + ready_cancel.set() + await self.manager.wait_finished() + finally: + self.cleanup_up = True + + service = TryFinallyService() + async with background_trio_service(service) as manager: + await ready_cancel.wait() + assert not service.cleanup_up + manager.cancel() + assert service.cleanup_up + + +@pytest.mark.trio +async def test_trio_service_with_try_finally_cleanup_with_unshielded_await(): + ready_cancel = trio.Event() + + class TryFinallyService(Service): + cleanup_up = False + + async def run(self) -> None: + try: + ready_cancel.set() + await self.manager.wait_finished() + finally: + await trio.lowlevel.checkpoint() + self.cleanup_up = True + + service = 
TryFinallyService() + async with background_trio_service(service) as manager: + await ready_cancel.wait() + assert not service.cleanup_up + manager.cancel() + assert not service.cleanup_up + + +@pytest.mark.trio +async def test_trio_service_with_try_finally_cleanup_with_shielded_await(): + ready_cancel = trio.Event() + + class TryFinallyService(Service): + cleanup_up = False + + async def run(self) -> None: + try: + ready_cancel.set() + await self.manager.wait_finished() + finally: + with trio.CancelScope(shield=True): + await trio.lowlevel.checkpoint() + self.cleanup_up = True + + service = TryFinallyService() + async with background_trio_service(service) as manager: + await ready_cancel.wait() + assert not service.cleanup_up + manager.cancel() + assert service.cleanup_up + + +@pytest.mark.trio +async def test_error_in_service_run(): + class ServiceTest(Service): + async def run(self): + self.manager.run_daemon_task(self.manager.wait_finished) + raise ValueError("Exception inside run()") + + with RaisesGroup(ValueError, strict=False): + await TrioManager.run_service(ServiceTest()) + + +@pytest.mark.trio +async def test_daemon_task_finishes_leaving_children(): + class ServiceTest(Service): + async def sleep_and_fail(self): + await trio.sleep(1) + raise AssertionError( + "This should not happen as the task should be cancelled" + ) + + async def buggy_daemon(self): + self.manager.run_task(self.sleep_and_fail) + + async def run(self): + self.manager.run_daemon_task(self.buggy_daemon) + + with RaisesGroup(DaemonTaskExit, strict=False): + await TrioManager.run_service(ServiceTest()) diff --git a/tests/core/tools/async_service/test_trio_external_api.py b/tests/core/tools/async_service/test_trio_external_api.py new file mode 100644 index 00000000..dae2f14a --- /dev/null +++ b/tests/core/tools/async_service/test_trio_external_api.py @@ -0,0 +1,109 @@ +# Copied from https://github.com/ethereum/async-service +import pytest +import trio +from trio.testing import ( + 
RaisesGroup, +) + +from libp2p.tools.async_service import ( + LifecycleError, + Service, + background_trio_service, +) +from libp2p.tools.async_service.trio_service import ( + external_api, +) + + +class ExternalAPIService(Service): + async def run(self): + await self.manager.wait_finished() + + @external_api + async def get_7(self, wait_return=None, signal_event=None): + if signal_event is not None: + signal_event.set() + if wait_return is not None: + await wait_return.wait() + return 7 + + +@pytest.mark.trio +async def test_trio_service_external_api_fails_before_start(): + service = ExternalAPIService() + + # should raise if the service has not yet been started. + with pytest.raises(LifecycleError): + await service.get_7() + + +@pytest.mark.trio +async def test_trio_service_external_api_works_while_running(): + service = ExternalAPIService() + + async with background_trio_service(service): + assert await service.get_7() == 7 + + +@pytest.mark.trio +async def test_trio_service_external_api_raises_when_cancelled(): + service = ExternalAPIService() + + async with background_trio_service(service) as manager: + with RaisesGroup(LifecycleError, strict=False): + async with trio.open_nursery() as nursery: + # an event to ensure that we are indeed within the body of the + is_within_fn = trio.Event() + trigger_return = trio.Event() + + nursery.start_soon(service.get_7, trigger_return, is_within_fn) + + # ensure we're within the body of the task. + await is_within_fn.wait() + + # now cancel the service and trigger the return of the function. + manager.cancel() + + # exiting the context block here will cause the background task + # to complete and shold raise the exception + + # A direct call should also fail. This *should* be hitting the early + # return mechanism. 
+ with pytest.raises(LifecycleError): + assert await service.get_7() + + +@pytest.mark.trio +async def test_trio_service_external_api_raises_when_finished(): + service = ExternalAPIService() + + async with background_trio_service(service) as manager: + pass + + assert manager.is_finished + # A direct call should also fail. This *should* be hitting the early + # return mechanism. + with pytest.raises(LifecycleError): + assert await service.get_7() + + +@pytest.mark.trio +async def test_trio_external_api_call_that_schedules_task(): + done = trio.Event() + + class MyService(Service): + async def run(self): + await self.manager.wait_finished() + + @external_api + async def do_scheduling(self): + self.manager.run_task(self.set_done) + + async def set_done(self): + done.set() + + service = MyService() + async with background_trio_service(service): + await service.do_scheduling() + with trio.fail_after(1): + await done.wait() diff --git a/tests/core/tools/async_service/test_trio_manager_stats.py b/tests/core/tools/async_service/test_trio_manager_stats.py new file mode 100644 index 00000000..659b2f8d --- /dev/null +++ b/tests/core/tools/async_service/test_trio_manager_stats.py @@ -0,0 +1,86 @@ +import pytest +import trio + +from libp2p.tools.async_service import ( + Service, + background_trio_service, +) + + +@pytest.mark.trio +async def test_trio_manager_stats(): + ready = trio.Event() + + class StatsTest(Service): + async def run(self): + # 2 that run forever + self.manager.run_task(trio.sleep_forever) + self.manager.run_task(trio.sleep_forever) + + # 2 that complete + self.manager.run_task(trio.lowlevel.checkpoint) + self.manager.run_task(trio.lowlevel.checkpoint) + + # 1 that spawns some children + self.manager.run_task(self.run_with_children, 4) + + async def run_with_children(self, num_children): + for _ in range(num_children): + self.manager.run_task(trio.sleep_forever) + ready.set() + + def run_external_root(self): + self.manager.run_task(trio.lowlevel.checkpoint) 
+ + service = StatsTest() + async with background_trio_service(service) as manager: + service.run_external_root() + assert len(manager._root_tasks) == 2 + with trio.fail_after(1): + await ready.wait() + + # we need to yield to the event loop a few times to allow the various + # tasks to schedule themselves and get running. + for _ in range(50): + await trio.lowlevel.checkpoint() + + assert manager.stats.tasks.total_count == 10 + assert manager.stats.tasks.finished_count == 3 + assert manager.stats.tasks.pending_count == 7 + + # This is a simple test to ensure that finished tasks are removed from + # tracking to prevent unbounded memory growth. + assert len(manager._root_tasks) == 1 + + # now check after exiting + assert manager.stats.tasks.total_count == 10 + assert manager.stats.tasks.finished_count == 10 + assert manager.stats.tasks.pending_count == 0 + + +@pytest.mark.trio +async def test_trio_manager_stats_does_not_count_main_run_method(): + ready = trio.Event() + + class StatsTest(Service): + async def run(self): + self.manager.run_task(trio.sleep_forever) + ready.set() + + async with background_trio_service(StatsTest()) as manager: + with trio.fail_after(1): + await ready.wait() + + # we need to yield to the event loop a few times to allow the various + # tasks to schedule themselves and get running. + for _ in range(10): + await trio.lowlevel.checkpoint() + + assert manager.stats.tasks.total_count == 1 + assert manager.stats.tasks.finished_count == 0 + assert manager.stats.tasks.pending_count == 1 + + # now check after exiting + assert manager.stats.tasks.total_count == 1 + assert manager.stats.tasks.finished_count == 1 + assert manager.stats.tasks.pending_count == 0