drop async-service dep and copy relevant code into a local async_service

tool, updated for modern handling of ExceptionGroup
This commit is contained in:
pacrob
2024-05-19 14:48:03 -06:00
committed by Paul Robinson
parent 7de6cbaab0
commit d9b92635c1
28 changed files with 2176 additions and 35 deletions

View File

@ -49,6 +49,7 @@ repos:
- id: mypy
additional_dependencies:
- mypy-protobuf
- trio-typing
exclude: 'tests/'
- repo: local
hooks:

View File

@ -0,0 +1,61 @@
libp2p.tools.async\_service package
===================================
Submodules
----------
libp2p.tools.async\_service.abc module
--------------------------------------
.. automodule:: libp2p.tools.async_service.abc
:members:
:undoc-members:
:show-inheritance:
libp2p.tools.async\_service.base module
---------------------------------------
.. automodule:: libp2p.tools.async_service.base
:members:
:undoc-members:
:show-inheritance:
libp2p.tools.async\_service.exceptions module
---------------------------------------------
.. automodule:: libp2p.tools.async_service.exceptions
:members:
:undoc-members:
:show-inheritance:
libp2p.tools.async\_service.stats module
----------------------------------------
.. automodule:: libp2p.tools.async_service.stats
:members:
:undoc-members:
:show-inheritance:
libp2p.tools.async\_service.trio\_service module
------------------------------------------------
.. automodule:: libp2p.tools.async_service.trio_service
:members:
:undoc-members:
:show-inheritance:
libp2p.tools.async\_service.typing module
-----------------------------------------
.. automodule:: libp2p.tools.async_service.typing
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.tools.async_service
:members:
:undoc-members:
:show-inheritance:

View File

@ -7,6 +7,7 @@ Subpackages
.. toctree::
:maxdepth: 4
libp2p.tools.async_service
libp2p.tools.pubsub
Submodules

View File

@ -9,9 +9,6 @@ from typing import (
Sequence,
)
from async_service import (
background_trio_service,
)
import multiaddr
from libp2p.crypto.keys import (
@ -52,6 +49,9 @@ from libp2p.protocol_muxer.multiselect_client import (
from libp2p.protocol_muxer.multiselect_communicator import (
MultiselectCommunicator,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.typing import (
StreamHandlerFn,
TProtocol,

View File

@ -8,9 +8,6 @@ from typing import (
Sequence,
)
from async_service import (
ServiceAPI,
)
from multiaddr import (
Multiaddr,
)
@ -24,6 +21,9 @@ from libp2p.peer.id import (
from libp2p.peer.peerstore_interface import (
IPeerStore,
)
from libp2p.tools.async_service import (
ServiceAPI,
)
from libp2p.transport.listener_interface import (
IListener,
)

View File

@ -5,9 +5,6 @@ from typing import (
Optional,
)
from async_service import (
Service,
)
from multiaddr import (
Multiaddr,
)
@ -31,6 +28,9 @@ from libp2p.peer.peerstore_interface import (
from libp2p.stream_muxer.abc import (
IMuxedConn,
)
from libp2p.tools.async_service import (
Service,
)
from libp2p.transport.exceptions import (
MuxerUpgradeFailure,
OpenConnectionError,

View File

@ -32,7 +32,7 @@ class MultiselectClient(IMultiselectClient):
Ensure that the client and multiselect are both using the same
multiselect protocol.
:param stream: stream to communicate with multiselect over
:param communicator: communicator to use to communicate with counterparty
:raise MultiselectClientError: raised when handshake failed
"""
try:
@ -57,7 +57,7 @@ class MultiselectClient(IMultiselectClient):
protocol that multiselect agrees on (i.e. that multiselect selects)
:param protocol: protocol to select
:param stream: stream to communicate with multiselect over
:param communicator: communicator to use to communicate with counterparty
:return: selected protocol
:raise MultiselectClientError: raised when protocol negotiation failed
"""

View File

@ -11,13 +11,12 @@ from typing import (
Tuple,
)
from async_service import (
ServiceAPI,
)
from libp2p.peer.id import (
ID,
)
from libp2p.tools.async_service import (
ServiceAPI,
)
from libp2p.typing import (
TProtocol,
)

View File

@ -17,9 +17,6 @@ from typing import (
Tuple,
)
from async_service import (
Service,
)
import trio
from libp2p.network.stream.exceptions import (
@ -31,6 +28,9 @@ from libp2p.peer.id import (
from libp2p.pubsub import (
floodsub,
)
from libp2p.tools.async_service import (
Service,
)
from libp2p.typing import (
TProtocol,
)

View File

@ -15,9 +15,6 @@ from typing import (
cast,
)
from async_service import (
Service,
)
import base58
from lru import (
LRU,
@ -51,6 +48,9 @@ from libp2p.network.stream.net_stream_interface import (
from libp2p.peer.id import (
ID,
)
from libp2p.tools.async_service import (
Service,
)
from libp2p.typing import (
TProtocol,
)

View File

@ -0,0 +1,15 @@
from .abc import (
ServiceAPI,
)
from .base import (
Service,
as_service,
)
from .exceptions import (
DaemonTaskExit,
LifecycleError,
)
from .trio_service import (
TrioManager,
background_trio_service,
)

View File

@ -0,0 +1,41 @@
# Copied from https://github.com/ethereum/async-service
import os
from typing import (
Any,
)
def get_task_name(value: Any, explicit_name: str = None) -> str:
# inline import to ensure `_utils` is always importable from the rest of
# the module.
from .abc import ( # noqa: F401
ServiceAPI,
)
if explicit_name is not None:
# if an explicit name was provided, just return that.
return explicit_name
elif isinstance(value, ServiceAPI):
# `Service` instance naming rules:
#
# 1. __str__ **if** the class implements a custom __str__ method
# 2. __repr__ **if** the class implements a custom __repr__ method
# 3. The `Service` class name.
value_cls = type(value)
if value_cls.__str__ is not object.__str__:
return str(value)
if value_cls.__repr__ is not object.__repr__:
return repr(value)
else:
return value.__class__.__name__
else:
try:
# Prefer the name of the function if it has one
return str(value.__name__) # mypy doesn't know __name__ is a `str`
except AttributeError:
return repr(value)
def is_verbose_logging_enabled() -> bool:
return bool(os.environ.get("ASYNC_SERVICE_VERBOSE_LOG", False))

View File

@ -0,0 +1,257 @@
# Copied from https://github.com/ethereum/async-service
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
Hashable,
Optional,
Set,
)
import trio_typing
from .stats import (
Stats,
)
from .typing import (
AsyncFn,
)
class TaskAPI(Hashable):
name: str
daemon: bool
parent: Optional["TaskWithChildrenAPI"]
@abstractmethod
async def run(self) -> None:
...
@abstractmethod
async def cancel(self) -> None:
...
@property
@abstractmethod
def is_done(self) -> bool:
...
@abstractmethod
async def wait_done(self) -> None:
...
class TaskWithChildrenAPI(TaskAPI):
children: Set[TaskAPI]
@abstractmethod
def add_child(self, child: TaskAPI) -> None:
...
@abstractmethod
def discard_child(self, child: TaskAPI) -> None:
...
class ServiceAPI(ABC):
_manager: "InternalManagerAPI"
@abstractmethod
def get_manager(self) -> "ManagerAPI":
"""
External retrieval of the manager for this service.
Will raise a :class:`~async_service.exceptions.LifecycleError` if the
service does not yet have a `manager` assigned to it.
"""
...
@abstractmethod
async def run(self) -> None:
"""
Primary entry point for all service logic.
.. note:: This method should **not** be directly invoked by user code.
Services may be run using the following approaches.
.. code-block: python
# 1. run the service in the background using a context manager
async with run_service(service) as manager:
# service runs inside context block
...
# service cancels and stops when context exits
# service will have fully stopped
# 2. run the service blocking until completion
await Manager.run_service(service)
# 3. create manager and then run service blocking until completion
manager = Manager(service)
await manager.run()
"""
...
class ManagerAPI(ABC):
@property
@abstractmethod
def is_started(self) -> bool:
"""
Return boolean indicating if the underlying service has been started.
"""
...
@property
@abstractmethod
def is_running(self) -> bool:
"""
Return boolean indicating if the underlying service is actively
running.
A service is considered running if it has been started and
has not yet been stopped.
"""
...
@property
@abstractmethod
def is_cancelled(self) -> bool:
"""
Return boolean indicating if the underlying service has been cancelled.
This can occure externally via the `cancel()` method or internally due
to a task crash or a crash of the actual :meth:`ServiceAPI.run` method.
"""
...
@property
@abstractmethod
def is_finished(self) -> bool:
"""
Return boolean indicating if the underlying service is stopped.
A stopped service will have completed all of the background tasks.
"""
...
@property
@abstractmethod
def did_error(self) -> bool:
"""
Return boolean indicating if the underlying service threw an exception.
"""
...
@abstractmethod
def cancel(self) -> None:
"""
Trigger cancellation of the service.
"""
...
@abstractmethod
async def stop(self) -> None:
"""
Trigger cancellation of the service and wait for it to finish.
"""
...
@abstractmethod
async def wait_started(self) -> None:
"""
Wait until the service is started.
"""
...
@abstractmethod
async def wait_finished(self) -> None:
"""
Wait until the service is stopped.
"""
...
@classmethod
@abstractmethod
async def run_service(cls, service: ServiceAPI) -> None:
"""
Run a service
"""
...
@abstractmethod
async def run(self) -> None:
"""
Run a service
"""
...
@property
@abstractmethod
def stats(self) -> Stats:
"""
Return a stats object with details about the service.
"""
...
class InternalManagerAPI(ManagerAPI):
"""
Defines the API that the `Service.manager` property exposes.
The InternalManagerAPI / ManagerAPI distinction is in place to ensure that
external callers to a service do not try to use the task scheduling
functionality as it is only designed to be used internally.
"""
@trio_typing.takes_callable_and_args
@abstractmethod
def run_task(
self, async_fn: AsyncFn, *args: Any, daemon: bool = False, name: str = None
) -> None:
"""
Run a task in the background. If the function throws an exception it
will trigger the service to be cancelled and be propogated.
If `daemon == True` then the the task is expected to run indefinitely
and will trigger cancellation if the task finishes.
"""
...
@trio_typing.takes_callable_and_args
@abstractmethod
def run_daemon_task(self, async_fn: AsyncFn, *args: Any, name: str = None) -> None:
"""
Run a daemon task in the background.
Equivalent to `run_task(..., daemon=True)`.
"""
...
@abstractmethod
def run_child_service(
self, service: ServiceAPI, daemon: bool = False, name: str = None
) -> "ManagerAPI":
"""
Run a service in the background. If the function throws an exception it
will trigger the parent service to be cancelled and be propogated.
If `daemon == True` then the the service is expected to run indefinitely
and will trigger cancellation if the service finishes.
"""
...
@abstractmethod
def run_daemon_child_service(
self, service: ServiceAPI, name: str = None
) -> "ManagerAPI":
"""
Run a daemon service in the background.
Equivalent to `run_child_service(..., daemon=True)`.
"""
...

View File

@ -0,0 +1,378 @@
# Copied from https://github.com/ethereum/async-service
from abc import (
abstractmethod,
)
import asyncio
from collections import (
Counter,
)
import logging
import sys
from typing import (
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
cast,
)
import uuid
from ._utils import (
is_verbose_logging_enabled,
)
from .abc import (
InternalManagerAPI,
ManagerAPI,
ServiceAPI,
TaskAPI,
TaskWithChildrenAPI,
)
from .exceptions import (
DaemonTaskExit,
LifecycleError,
TooManyChildrenException,
)
from .stats import (
Stats,
TaskStats,
)
from .typing import (
EXC_INFO,
AsyncFn,
)
MAX_CHILDREN_TASKS = 1000
class Service(ServiceAPI):
def __str__(self) -> str:
return self.__class__.__name__
@property
def manager(self) -> "InternalManagerAPI":
"""
Expose the manager as a property here intead of
:class:`async_service.abc.ServiceAPI` to ensure that anyone using
proper type hints will not have access to this property since it isn't
part of that API, while still allowing all subclasses of the
:class:`async_service.base.Service` to access this property directly.
"""
return self._manager
def get_manager(self) -> ManagerAPI:
try:
return self._manager
except AttributeError:
raise LifecycleError(
"Service does not have a manager assigned to it. Are you sure "
"it is running?"
)
LogicFnType = Callable[..., Awaitable[Any]]
def as_service(service_fn: LogicFnType) -> Type[ServiceAPI]:
"""
Create a service out of a simple function
"""
class _Service(Service):
def __init__(self, *args: Any, **kwargs: Any):
self._args = args
self._kwargs = kwargs
async def run(self) -> None:
await service_fn(self.manager, *self._args, **self._kwargs)
_Service.__name__ = service_fn.__name__
_Service.__doc__ = service_fn.__doc__
return _Service
class BaseTask(TaskAPI):
def __init__(
self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI]
) -> None:
# meta
self.name = name
self.daemon = daemon
# parent task
self.parent = parent
# For hashable interface.
self._id = uuid.uuid4()
def __hash__(self) -> int:
return hash(self._id)
def __eq__(self, other: Any) -> bool:
if isinstance(other, TaskAPI):
return hash(self) == hash(other)
else:
return False
def __str__(self) -> str:
return f"{self.name}[daemon={self.daemon}]"
class BaseTaskWithChildren(BaseTask, TaskWithChildrenAPI):
def __init__(
self, name: str, daemon: bool, parent: Optional[TaskWithChildrenAPI]
) -> None:
super().__init__(name, daemon, parent)
self.children = set()
def add_child(self, child: TaskAPI) -> None:
self.children.add(child)
def discard_child(self, child: TaskAPI) -> None:
self.children.discard(child)
T = TypeVar("T", bound="BaseFunctionTask")
class BaseFunctionTask(BaseTaskWithChildren):
@classmethod
def iterate_tasks(cls: Type[T], *tasks: TaskAPI) -> Iterable[T]:
for task in tasks:
if isinstance(task, cls):
yield task
else:
continue
yield from cls.iterate_tasks(
*(
child_task
for child_task in task.children
if isinstance(child_task, cls)
)
)
def __init__(
self,
name: str,
daemon: bool,
parent: Optional[TaskWithChildrenAPI],
async_fn: AsyncFn,
async_fn_args: Sequence[Any],
) -> None:
super().__init__(name, daemon, parent)
self._async_fn = async_fn
self._async_fn_args = async_fn_args
class BaseChildServiceTask(BaseTask):
_child_service: ServiceAPI
child_manager: ManagerAPI
async def run(self) -> None:
if self.child_manager.is_started:
raise LifecycleError(
f"Child service {self._child_service} has already been started"
)
try:
await self.child_manager.run()
if self.daemon:
raise DaemonTaskExit(f"Daemon task {self} exited")
finally:
if self.parent is not None:
self.parent.discard_child(self)
@property
def is_done(self) -> bool:
return self.child_manager.is_finished
async def wait_done(self) -> None:
if self.child_manager.is_started:
await self.child_manager.wait_finished()
class BaseManager(InternalManagerAPI):
logger = logging.getLogger("async_service.Manager")
_verbose = is_verbose_logging_enabled()
_service: ServiceAPI
_errors: List[EXC_INFO]
def __init__(self, service: ServiceAPI) -> None:
if hasattr(service, "_manager"):
raise LifecycleError("Service already has a manager.")
else:
service._manager = self
self._service = service
# errors
self._errors = []
# tasks
self._root_tasks: Set[TaskAPI] = set()
# stats
self._total_task_count = 0
self._done_task_count = 0
def __str__(self) -> str:
status_flags = "".join(
(
"S" if self.is_started else "s",
"R" if self.is_running else "r",
"C" if self.is_cancelled else "c",
"F" if self.is_finished else "f",
"E" if self.did_error else "e",
)
)
return f"<Manager[{self._service}] flags={status_flags}>"
#
# Event API mirror
#
@property
def is_running(self) -> bool:
return self.is_started and not self.is_finished
@property
def did_error(self) -> bool:
return len(self._errors) > 0
#
# Control API
#
async def stop(self) -> None:
self.cancel()
await self.wait_finished()
#
# Wait API
#
def run_daemon_task(
self, async_fn: Callable[..., Awaitable[Any]], *args: Any, name: str = None
) -> None:
self.run_task(async_fn, *args, daemon=True, name=name)
def run_daemon_child_service(
self, service: ServiceAPI, name: str = None
) -> ManagerAPI:
return self.run_child_service(service, daemon=True, name=name)
@property
def stats(self) -> Stats:
# The `max` call here ensures that if this is called prior to the
# `Service.run` method starting we don't return `-1`
total_count = max(0, self._total_task_count)
# Since we track `Service.run` as a task, the `min` call here ensures
# that when the service is fully done that we don't represent the
# `Service.run` method in this count.
finished_count = min(total_count, self._done_task_count)
return Stats(
tasks=TaskStats(total_count=total_count, finished_count=finished_count)
)
#
# Task Management
#
@abstractmethod
def _schedule_task(self, task: TaskAPI) -> None:
...
def _common_run_task(self, task: TaskAPI) -> None:
if not self.is_running:
raise LifecycleError(
"Tasks may not be scheduled if the service is not running"
)
if self.is_running and self.is_cancelled:
self.logger.debug(
"%s: service is being cancelled. Not running task %s", self, task
)
return
self._add_child_task(task.parent, task)
self._total_task_count += 1
self._schedule_task(task)
def _add_child_task(
self, parent: Optional[TaskWithChildrenAPI], task: TaskAPI
) -> None:
if parent is None:
all_children = self._root_tasks
else:
all_children = parent.children
if len(all_children) > MAX_CHILDREN_TASKS:
task_counter = Counter(map(str, all_children))
raise TooManyChildrenException(
f"Tried to add more than {MAX_CHILDREN_TASKS} child tasks."
f" Most common tasks: {task_counter.most_common(10)}"
)
if parent is None:
if self._verbose:
self.logger.debug("%s: running root task %s", self, task)
self._root_tasks.add(task)
else:
if self._verbose:
self.logger.debug("%s: %s running child task %s", self, parent, task)
parent.add_child(task)
async def _run_and_manage_task(self, task: TaskAPI) -> None:
if self._verbose:
self.logger.debug("%s: task %s running", self, task)
try:
try:
await task.run()
except DaemonTaskExit:
if self.is_cancelled:
pass
else:
raise
finally:
if isinstance(task, TaskWithChildrenAPI):
new_parent = task.parent
for child in task.children:
child.parent = new_parent
self._add_child_task(new_parent, child)
self.logger.debug(
"%s left a child task (%s) behind, reassigning it to %s",
task,
child,
new_parent or "root",
)
except asyncio.CancelledError:
self.logger.debug("%s: task %s raised CancelledError.", self, task)
raise
except Exception as err:
self.logger.error(
"%s: task %s exited with error: %s",
self,
task,
err,
# Only show stacktrace if this is **not** a DaemonTaskExit error
exc_info=not isinstance(err, DaemonTaskExit),
)
self._errors.append(cast(EXC_INFO, sys.exc_info()))
self.cancel()
else:
if task.parent is None:
self._root_tasks.remove(task)
if self._verbose:
self.logger.debug("%s: task %s exited cleanly.", self, task)
finally:
self._done_task_count += 1

View File

@ -0,0 +1,26 @@
# Copied from https://github.com/ethereum/async-service
class ServiceException(Exception):
"""
Base class for Service exceptions
"""
class LifecycleError(ServiceException):
"""
Raised when an action would violate the service lifecycle rules.
"""
class DaemonTaskExit(ServiceException):
"""
Raised when an action would violate the service lifecycle rules.
"""
class TooManyChildrenException(ServiceException):
"""
Raised when a service adds too many children. It is a sign of task leakage
that needs to be prevented.
"""

View File

@ -0,0 +1,18 @@
# Copied from https://github.com/ethereum/async-service
from typing import (
NamedTuple,
)
class TaskStats(NamedTuple):
total_count: int
finished_count: int
@property
def pending_count(self) -> int:
return self.total_count - self.finished_count
class Stats(NamedTuple):
tasks: TaskStats

View File

@ -0,0 +1,446 @@
# Originally copied from https://github.com/ethereum/async-service
from __future__ import (
annotations,
)
from contextlib import (
asynccontextmanager,
)
import functools
import sys
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Optional,
Sequence,
Tuple,
TypeVar,
cast,
)
if sys.version_info >= (3, 11):
from builtins import (
ExceptionGroup,
)
else:
from exceptiongroup import ExceptionGroup
import trio
import trio_typing
from ._utils import (
get_task_name,
)
from .abc import (
ManagerAPI,
ServiceAPI,
TaskAPI,
TaskWithChildrenAPI,
)
from .base import (
BaseChildServiceTask,
BaseFunctionTask,
BaseManager,
)
from .exceptions import (
DaemonTaskExit,
LifecycleError,
)
from .typing import (
EXC_INFO,
AsyncFn,
)
class FunctionTask(BaseFunctionTask):
_trio_task: trio.lowlevel.Task | None = None
def __init__(
self,
name: str,
daemon: bool,
parent: TaskWithChildrenAPI | None,
async_fn: AsyncFn,
async_fn_args: Sequence[Any],
) -> None:
super().__init__(name, daemon, parent, async_fn, async_fn_args)
# We use an event to manually track when the child task is "done".
# This is because trio has no API for awaiting completion of a task.
self._done = trio.Event()
# Each task gets its own `CancelScope` which is how we can manually
# control cancellation order of the task DAG
self._cancel_scope = trio.CancelScope()
#
# Trio specific API
#
@property
def has_trio_task(self) -> bool:
return self._trio_task is not None
@property
def trio_task(self) -> trio.lowlevel.Task:
if self._trio_task is None:
raise LifecycleError("Trio task not set yet")
return self._trio_task
@trio_task.setter
def trio_task(self, value: trio.lowlevel.Task) -> None:
if self._trio_task is not None:
raise LifecycleError(f"Task already set: {self._trio_task}")
self._trio_task = value
#
# Core Task API
#
async def run(self) -> None:
self.trio_task = trio.lowlevel.current_task()
try:
with self._cancel_scope:
await self._async_fn(*self._async_fn_args)
if self.daemon:
raise DaemonTaskExit(f"Daemon task {self} exited")
while self.children:
await tuple(self.children)[0].wait_done()
finally:
self._done.set()
if self.parent is not None:
self.parent.discard_child(self)
async def cancel(self) -> None:
for task in tuple(self.children):
await task.cancel()
self._cancel_scope.cancel()
await self.wait_done()
@property
def is_done(self) -> bool:
return self._done.is_set()
async def wait_done(self) -> None:
await self._done.wait()
class ChildServiceTask(BaseChildServiceTask):
def __init__(
self,
name: str,
daemon: bool,
parent: TaskWithChildrenAPI | None,
child_service: ServiceAPI,
) -> None:
super().__init__(name, daemon, parent)
self._child_service = child_service
self.child_manager = TrioManager(child_service)
async def cancel(self) -> None:
if self.child_manager.is_started:
await self.child_manager.stop()
class TrioManager(BaseManager):
# A nursery for sub tasks and services. This nursery is cancelled if the
# service is cancelled but allowed to exit normally if the service exits.
_task_nursery: trio_typing.Nursery
def __init__(self, service: ServiceAPI) -> None:
super().__init__(service)
# events
self._started = trio.Event()
self._cancelled = trio.Event()
self._finished = trio.Event()
# locks
self._run_lock = trio.Lock()
#
# System Tasks
#
async def _handle_cancelled(self) -> None:
self.logger.debug("%s: _handle_cancelled waiting for cancellation", self)
await self._cancelled.wait()
self.logger.debug("%s: _handle_cancelled triggering task cancellation", self)
# The `_root_tasks` changes size as each task completes itself
# and removes itself from the set. For this reason we iterate over a
# copy of the set.
for task in tuple(self._root_tasks):
await task.cancel()
# This finaly cancellation of the task nursery's cancel scope ensures
# that nothing is left behind and that the service will reliably exit.
self._task_nursery.cancel_scope.cancel()
@classmethod
async def run_service(cls, service: ServiceAPI) -> None:
manager = cls(service)
await manager.run()
async def run(self) -> None:
if self._run_lock.locked():
raise LifecycleError(
"Cannot run a service with the run lock already engaged. "
"Already started?"
)
elif self.is_started:
raise LifecycleError("Cannot run a service which is already started.")
try:
async with self._run_lock:
async with trio.open_nursery() as system_nursery:
system_nursery.start_soon(self._handle_cancelled)
try:
async with trio.open_nursery() as task_nursery:
self._task_nursery = task_nursery
self._started.set()
self.run_task(self._service.run, name="run")
# This is hack to get the task stats correct. We don't want
# to count the `Service.run` method as a task. This is still
# imperfect as it will still count as a completed task when
# it finishes.
self._total_task_count = 0
# ***BLOCKING HERE***
# The code flow will block here until the background tasks
# have completed or cancellation occurs.
except Exception:
# Exceptions from any tasks spawned by our service will be
# caught by trio and raised here, so we store them to report
# together with any others we have already captured.
self._errors.append(cast(EXC_INFO, sys.exc_info()))
finally:
system_nursery.cancel_scope.cancel()
finally:
# We need this inside a finally because a trio.Cancelled exception may be
# raised here and it wouldn't be swalled by the 'except Exception' above.
self._finished.set()
self.logger.debug("%s: finished", self)
# This is outside of the finally block above because we don't want to suppress
# trio.Cancelled or ExceptionGroup exceptions coming directly from trio.
if self.did_error:
raise ExceptionGroup(
"Encountered multiple Exceptions: ",
tuple(
exc_value.with_traceback(exc_tb)
for _, exc_value, exc_tb in self._errors
if isinstance(exc_value, Exception)
),
)
#
# Event API mirror
#
@property
def is_started(self) -> bool:
return self._started.is_set()
@property
def is_cancelled(self) -> bool:
return self._cancelled.is_set()
@property
def is_finished(self) -> bool:
return self._finished.is_set()
#
# Control API
#
def cancel(self) -> None:
if not self.is_started:
raise LifecycleError("Cannot cancel as service which was never started.")
elif not self.is_running:
return
else:
self._cancelled.set()
#
# Wait API
#
async def wait_started(self) -> None:
await self._started.wait()
async def wait_finished(self) -> None:
await self._finished.wait()
def _find_parent_task(
self, trio_task: trio.lowlevel.Task
) -> TaskWithChildrenAPI | None:
"""
Find the :class:`async_service.trio.FunctionTask` instance that corresponds to
the given :class:`trio.lowlevel.Task` instance.
"""
for task in FunctionTask.iterate_tasks(*self._root_tasks):
# Any task that has not had its `trio_task` set can be safely
# skipped as those are still in the process of starting up which
# means that they cannot be the parent task since they will not
# have had a chance to schedule child tasks.
if not task.has_trio_task:
continue
if trio_task is task.trio_task:
return task
else:
# In the case that no tasks match we assume this is a new `root`
# task and return `None` as the parent.
return None
def _schedule_task(self, task: TaskAPI) -> None:
self._task_nursery.start_soon(self._run_and_manage_task, task, name=str(task))
def run_task(
self,
async_fn: Callable[..., Awaitable[Any]],
*args: Any,
daemon: bool = False,
name: str = None,
) -> None:
task = FunctionTask(
name=get_task_name(async_fn, name),
daemon=daemon,
parent=self._find_parent_task(trio.lowlevel.current_task()),
async_fn=async_fn,
async_fn_args=args,
)
self._common_run_task(task)
def run_child_service(
self, service: ServiceAPI, daemon: bool = False, name: str = None
) -> ManagerAPI:
task = ChildServiceTask(
name=get_task_name(service, name),
daemon=daemon,
parent=self._find_parent_task(trio.lowlevel.current_task()),
child_service=service,
)
self._common_run_task(task)
return task.child_manager
TFunc = TypeVar("TFunc", bound=Callable[..., Coroutine[Any, Any, Any]])
_ChannelPayload = Tuple[Optional[Any], Optional[BaseException]]
async def _wait_finished(
service: ServiceAPI,
api_func: Callable[..., Any],
channel: trio.abc.SendChannel[_ChannelPayload],
) -> None:
manager = service.get_manager()
if manager.is_finished:
await channel.send(
(
None,
LifecycleError(
f"Cannot access external API {api_func}. "
f"Service {service} is not running: "
),
)
)
return
await manager.wait_finished()
await channel.send(
(
None,
LifecycleError(
f"Cannot access external API {api_func}. "
f"Service {service} is not running: "
),
)
)
async def _wait_api_fn(
self: ServiceAPI,
api_fn: Callable[..., Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
channel: trio.abc.SendChannel[_ChannelPayload],
) -> None:
try:
result = await api_fn(self, *args, **kwargs)
except Exception:
_, exc_value, exc_tb = sys.exc_info()
if exc_value is None or exc_tb is None:
raise Exception(
"This should be unreachable but acts as a type guard for mypy"
)
await channel.send((None, exc_value.with_traceback(exc_tb)))
else:
await channel.send((result, None))
def external_api(func: TFunc) -> TFunc:
@functools.wraps(func)
async def inner(self: ServiceAPI, *args: Any, **kwargs: Any) -> Any:
if not hasattr(self, "manager"):
raise LifecycleError(
f"Cannot access external API {func}. Service {self} has not been run."
)
manager = self.get_manager()
if not manager.is_running:
raise LifecycleError(
f"Cannot access external API {func}. Service {self} is not running: "
)
channels: tuple[
trio.abc.SendChannel[_ChannelPayload],
trio.abc.ReceiveChannel[_ChannelPayload],
] = trio.open_memory_channel(0)
send_channel, receive_channel = channels
async with trio.open_nursery() as nursery:
# mypy's type hints for start_soon break with this invocation.
nursery.start_soon(
_wait_api_fn, self, func, args, kwargs, send_channel # type: ignore
)
nursery.start_soon(_wait_finished, self, func, send_channel)
result, err = await receive_channel.receive()
nursery.cancel_scope.cancel()
if err is None:
return result
else:
raise err
return cast(TFunc, inner)
@asynccontextmanager
async def background_trio_service(service: ServiceAPI) -> AsyncIterator[ManagerAPI]:
"""
Run a service in the background.
The service is running within the context
block and will be properly cleaned up upon exiting the context block.
"""
async with trio.open_nursery() as nursery:
manager = TrioManager(service)
nursery.start_soon(manager.run)
await manager.wait_started()
try:
yield manager
finally:
await manager.stop()

View File

@ -0,0 +1,16 @@
# Copied from https://github.com/ethereum/async-service
from types import (
TracebackType,
)
from typing import (
Any,
Awaitable,
Callable,
Tuple,
Type,
)
EXC_INFO = Tuple[Type[BaseException], BaseException, TracebackType]
AsyncFn = Callable[..., Awaitable[Any]]

View File

@ -15,9 +15,6 @@ from typing import (
from async_exit_stack import (
AsyncExitStack,
)
from async_service import (
background_trio_service,
)
import factory
from multiaddr import (
Multiaddr,
@ -111,6 +108,9 @@ from libp2p.stream_muxer.mplex.mplex import (
from libp2p.stream_muxer.mplex.mplex_stream import (
MplexStream,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.constants import (
GOSSIPSUB_PARAMS,
)

View File

@ -56,7 +56,8 @@ class BaseInteractiveProcess(AbstractInterativeProcess):
async def start(self) -> None:
if self.proc is not None:
return
self.proc = await trio.open_process(
# mypy says that `open_process` is not an attribute of trio, suggests run_process instead. # noqa: E501
self.proc = await trio.open_process( # type: ignore[attr-defined]
[self.cmd] + self.args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier # noqa: E501

View File

@ -10,10 +10,6 @@ from typing import (
from async_exit_stack import (
AsyncExitStack,
)
from async_service import (
Service,
background_trio_service,
)
from libp2p.host.host_interface import (
IHost,
@ -21,6 +17,10 @@ from libp2p.host.host_interface import (
from libp2p.pubsub.pubsub import (
Pubsub,
)
from libp2p.tools.async_service import (
Service,
background_trio_service,
)
from libp2p.tools.factories import (
PubsubFactory,
)

View File

@ -0,0 +1 @@
Drop dep for unmaintained ``async-service`` and copy relevant functions into a local tool of the same name

View File

@ -62,9 +62,10 @@ install_requires = [
"coincurve>=10.0.0",
"pynacl==1.3.0",
"trio>=0.15.0",
"async-service>=0.1.0a6",
"async-exit-stack==1.0.1",
"noiseprotocol>=0.3.0",
"trio-typing>=0.0.4",
"exceptiongroup>=1.2.0; python_version < '3.11'",
# added during debugging
"anyio",
"p2pclient",

View File

@ -10,15 +10,15 @@ features are implemented in swarm
"""
import enum
from async_service import (
background_trio_service,
)
import pytest
import trio
from libp2p.network.notifee_interface import (
INotifee,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.constants import (
LISTEN_MADDR,
)

View File

@ -1,4 +1,7 @@
import pytest
from trio.testing import (
RaisesGroup,
)
from libp2p.host.exceptions import (
StreamFailure,
@ -58,7 +61,13 @@ async def test_single_protocol_succeeds(security_protocol):
@pytest.mark.trio
async def test_single_protocol_fails(security_protocol):
with pytest.raises(StreamFailure):
# using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups
# yet: https://github.com/pytest-dev/pytest/issues/11538
# but switch to that once they do
# the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False
# to unwrap down to the core Exception
with RaisesGroup(StreamFailure, strict=False):
await perform_simple_test(
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol
)
@ -96,7 +105,14 @@ async def test_multiple_protocol_second_is_valid_succeeds(security_protocol):
async def test_multiple_protocol_fails(security_protocol):
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(StreamFailure):
# using trio.testing.RaisesGroup b/c pytest.raises does not handle ExceptionGroups
# yet: https://github.com/pytest-dev/pytest/issues/11538
# but switch to that once they do
# the StreamFailure is within 2 nested ExceptionGroups, so we use strict=False
# to unwrap down to the core Exception
with RaisesGroup(StreamFailure, strict=False):
await perform_simple_test(
"", protocols_for_client, protocols_for_listener, security_protocol
)

View File

@ -0,0 +1,668 @@
import sys
if sys.version_info >= (3, 11):
from builtins import (
ExceptionGroup,
)
else:
from exceptiongroup import (
ExceptionGroup,
)
import pytest
import trio
from trio.testing import (
Matcher,
RaisesGroup,
)
from libp2p.tools.async_service import (
DaemonTaskExit,
LifecycleError,
Service,
TrioManager,
as_service,
background_trio_service,
)
class WaitCancelledService(Service):
async def run(self) -> None:
await self.manager.wait_finished()
async def do_service_lifecycle_check(
manager, manager_run_fn, trigger_exit_condition_fn, should_be_cancelled
):
async with trio.open_nursery() as nursery:
assert manager.is_started is False
assert manager.is_running is False
assert manager.is_cancelled is False
assert manager.is_finished is False
nursery.start_soon(manager_run_fn)
with trio.fail_after(0.1):
await manager.wait_started()
assert manager.is_started is True
assert manager.is_running is True
assert manager.is_cancelled is False
assert manager.is_finished is False
# trigger the service to exit
trigger_exit_condition_fn()
with trio.fail_after(0.1):
await manager.wait_finished()
if should_be_cancelled:
assert manager.is_started is True
# We cannot determine whether the service should be running at this
# stage because a service is considered running until it is
# finished. Since it may be cancelled but still not finished we
# can't know.
assert manager.is_cancelled is True
# We also cannot determine whether a service should be finished at this
# stage as it could have exited cleanly and is now finished or it
# might be doing some cleanup after which it will register as being
# finished.
assert manager.is_running is True or manager.is_finished is True
assert manager.is_started is True
assert manager.is_running is False
assert manager.is_cancelled is should_be_cancelled
assert manager.is_finished is True
def test_service_manager_initial_state():
service = WaitCancelledService()
manager = TrioManager(service)
assert manager.is_started is False
assert manager.is_running is False
assert manager.is_cancelled is False
assert manager.is_finished is False
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_clean_exit():
trigger_exit = trio.Event()
@as_service
async def ServiceTest(manager):
await trigger_exit.wait()
service = ServiceTest()
manager = TrioManager(service)
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=manager.run,
trigger_exit_condition_fn=trigger_exit.set,
should_be_cancelled=False,
)
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_external_cancellation():
@as_service
async def ServiceTest(manager):
await trio.sleep_forever()
service = ServiceTest()
manager = TrioManager(service)
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=manager.run,
trigger_exit_condition_fn=manager.cancel,
should_be_cancelled=True,
)
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_exception():
trigger_error = trio.Event()
@as_service
async def ServiceTest(manager):
await trigger_error.wait()
raise RuntimeError("Service throwing error")
service = ServiceTest()
manager = TrioManager(service)
async def do_service_run():
with RaisesGroup(
Matcher(RuntimeError, match="Service throwing error"), strict=False
):
await manager.run()
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=do_service_run,
trigger_exit_condition_fn=trigger_error.set,
should_be_cancelled=True,
)
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_task_exception():
trigger_error = trio.Event()
@as_service
async def ServiceTest(manager):
async def task_fn():
await trigger_error.wait()
raise RuntimeError("Service throwing error")
manager.run_task(task_fn)
service = ServiceTest()
manager = TrioManager(service)
async def do_service_run():
with RaisesGroup(
Matcher(RuntimeError, match="Service throwing error"), strict=False
):
await manager.run()
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=do_service_run,
trigger_exit_condition_fn=trigger_error.set,
should_be_cancelled=True,
)
@pytest.mark.trio
async def test_sub_service_cancelled_when_parent_stops():
ready_cancel = trio.Event()
# This test runs a service that runs a sub-service that sleeps forever. When the
# parent exits, the sub-service should be cancelled as well.
@as_service
async def WaitForeverService(manager):
ready_cancel.set()
await manager.wait_finished()
sub_manager = TrioManager(WaitForeverService())
@as_service
async def ServiceTest(manager):
async def run_sub():
await sub_manager.run()
manager.run_task(run_sub)
await manager.wait_finished()
s = ServiceTest()
async with background_trio_service(s) as manager:
await ready_cancel.wait()
assert not manager.is_running
assert manager.is_cancelled
assert manager.is_finished
assert not sub_manager.is_running
assert not sub_manager.is_cancelled
assert sub_manager.is_finished
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_daemon_task_exit():
trigger_error = trio.Event()
@as_service
async def ServiceTest(manager):
async def daemon_task_fn():
await trigger_error.wait()
manager.run_daemon_task(daemon_task_fn)
await manager.wait_finished()
service = ServiceTest()
manager = TrioManager(service)
async def do_service_run():
with RaisesGroup(Matcher(DaemonTaskExit, match="Daemon task"), strict=False):
await manager.run()
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=do_service_run,
trigger_exit_condition_fn=trigger_error.set,
should_be_cancelled=True,
)
@pytest.mark.trio
async def test_exceptiongroup_in_run():
# This test should cause TrioManager.run() to explicitly raise an ExceptionGroup
# containing two exceptions -- one raised inside its run() method and another
# raised by the daemon task exiting early.
trigger_error = trio.Event()
class ServiceTest(Service):
async def run(self):
ready = trio.Event()
self.manager.run_task(self.error_fn, ready)
await ready.wait()
trigger_error.set()
raise RuntimeError("Exception inside Service.run()")
async def error_fn(self, ready):
ready.set()
await trigger_error.wait()
raise ValueError("Exception inside error_fn")
with pytest.raises(ExceptionGroup) as exc_info:
await TrioManager.run_service(ServiceTest())
exc = exc_info.value
assert len(exc.exceptions) == 2
assert any(isinstance(err, RuntimeError) for err in exc.exceptions)
assert any(isinstance(err, ValueError) for err in exc.exceptions)
@pytest.mark.trio
async def test_trio_service_background_service_context_manager():
service = WaitCancelledService()
async with background_trio_service(service) as manager:
# ensure the manager property is set.
assert hasattr(service, "manager")
assert service.get_manager() is manager
assert manager.is_started is True
assert manager.is_running is True
assert manager.is_cancelled is False
assert manager.is_finished is False
assert manager.is_started is True
assert manager.is_running is False
assert manager.is_cancelled is True
assert manager.is_finished is True
@pytest.mark.trio
async def test_trio_service_manager_stop():
service = WaitCancelledService()
async with background_trio_service(service) as manager:
assert manager.is_started is True
assert manager.is_running is True
assert manager.is_cancelled is False
assert manager.is_finished is False
await manager.stop()
assert manager.is_started is True
assert manager.is_running is False
assert manager.is_cancelled is True
assert manager.is_finished is True
@pytest.mark.trio
async def test_trio_service_manager_run_task():
task_event = trio.Event()
@as_service
async def RunTaskService(manager):
async def task_fn():
task_event.set()
manager.run_task(task_fn)
await manager.wait_finished()
async with background_trio_service(RunTaskService()):
with trio.fail_after(0.1):
await task_event.wait()
@pytest.mark.trio
async def test_trio_service_manager_run_task_waits_for_task_completion():
task_event = trio.Event()
@as_service
async def RunTaskService(manager):
async def task_fn():
await trio.sleep(0.01)
task_event.set()
manager.run_task(task_fn)
# the task is set to run in the background but then the service exits.
# We want to be sure that the task is allowed to continue till
# completion unless explicitely cancelled.
async with background_trio_service(RunTaskService()):
with trio.fail_after(0.1):
await task_event.wait()
@pytest.mark.trio
async def test_trio_service_manager_run_task_can_still_cancel_after_run_finishes():
task_event = trio.Event()
service_finished = trio.Event()
@as_service
async def RunTaskService(manager):
async def task_fn():
# this will never complete
await task_event.wait()
manager.run_task(task_fn)
# the task is set to run in the background but then the service exits.
# We want to be sure that the task is allowed to continue till
# completion unless explicitely cancelled.
service_finished.set()
async with background_trio_service(RunTaskService()) as manager:
with trio.fail_after(0.01):
await service_finished.wait()
# show that the service hangs waiting for the task to complete.
with trio.move_on_after(0.01) as cancel_scope:
await manager.wait_finished()
assert cancel_scope.cancelled_caught is True
# trigger cancellation and see that the service actually stops
manager.cancel()
with trio.fail_after(0.01):
await manager.wait_finished()
@pytest.mark.trio
async def test_trio_service_manager_run_task_reraises_exceptions():
task_event = trio.Event()
@as_service
async def RunTaskService(manager):
async def task_fn():
await task_event.wait()
raise Exception("task exception in run_task")
manager.run_task(task_fn)
with trio.fail_after(1):
await trio.sleep_forever()
with RaisesGroup(
Matcher(Exception, match="task exception in run_task"), strict=False
):
async with background_trio_service(RunTaskService()):
task_event.set()
with trio.fail_after(1):
await trio.sleep_forever()
@pytest.mark.trio
async def test_trio_service_manager_run_daemon_task_cancels_if_exits():
task_event = trio.Event()
@as_service
async def RunTaskService(manager):
async def daemon_task_fn():
await task_event.wait()
manager.run_daemon_task(daemon_task_fn, name="daemon_task_fn")
with trio.fail_after(1):
await trio.sleep_forever()
with RaisesGroup(
Matcher(
DaemonTaskExit, match=r"Daemon task daemon_task_fn\[daemon=True\] exited"
),
strict=False,
):
async with background_trio_service(RunTaskService()):
task_event.set()
with trio.fail_after(1):
await trio.sleep_forever()
@pytest.mark.trio
async def test_trio_service_manager_propogates_and_records_exceptions():
@as_service
async def ThrowErrorService(manager):
raise RuntimeError("this is the error")
service = ThrowErrorService()
manager = TrioManager(service)
assert manager.did_error is False
with RaisesGroup(Matcher(RuntimeError, match="this is the error"), strict=False):
await manager.run()
assert manager.did_error is True
@pytest.mark.trio
async def test_trio_service_lifecycle_run_and_clean_exit_with_child_service():
trigger_exit = trio.Event()
@as_service
async def ChildServiceTest(manager):
await trigger_exit.wait()
@as_service
async def ServiceTest(manager):
child_manager = manager.run_child_service(ChildServiceTest())
await child_manager.wait_started()
service = ServiceTest()
manager = TrioManager(service)
await do_service_lifecycle_check(
manager=manager,
manager_run_fn=manager.run,
trigger_exit_condition_fn=trigger_exit.set,
should_be_cancelled=False,
)
@pytest.mark.trio
async def test_trio_service_with_daemon_child_service():
ready = trio.Event()
@as_service
async def ChildServiceTest(manager):
await manager.wait_finished()
@as_service
async def ServiceTest(manager):
child_manager = manager.run_daemon_child_service(ChildServiceTest())
await child_manager.wait_started()
ready.set()
await manager.wait_finished()
service = ServiceTest()
async with background_trio_service(service):
await ready.wait()
@pytest.mark.trio
async def test_trio_service_with_daemon_child_task():
ready = trio.Event()
started = trio.Event()
async def _task():
started.set()
await trio.sleep(100)
@as_service
async def ServiceTest(manager):
manager.run_daemon_task(_task)
await started.wait()
ready.set()
await manager.wait_finished()
service = ServiceTest()
async with background_trio_service(service):
await ready.wait()
@pytest.mark.trio
async def test_trio_service_with_async_generator():
is_within_agen = trio.Event()
async def do_agen():
while True:
yield
@as_service
async def ServiceTest(manager):
async for _ in do_agen(): # noqa: F841
await trio.lowlevel.checkpoint()
is_within_agen.set()
async with background_trio_service(ServiceTest()) as manager:
await is_within_agen.wait()
manager.cancel()
@pytest.mark.trio
async def test_trio_service_disallows_task_scheduling_when_not_running():
class ServiceTest(Service):
async def run(self):
await self.manager.wait_finished()
def do_schedule(self):
self.manager.run_task(trio.sleep, 1)
service = ServiceTest()
async with background_trio_service(service):
service.do_schedule()
with pytest.raises(LifecycleError):
service.do_schedule()
@pytest.mark.trio
async def test_trio_service_disallows_task_scheduling_after_cancel():
@as_service
async def ServiceTest(manager):
manager.cancel()
manager.run_task(trio.sleep, 1)
await TrioManager.run_service(ServiceTest())
@pytest.mark.trio
async def test_trio_service_cancellation_with_running_daemon_task():
in_daemon = trio.Event()
class ServiceTest(Service):
async def run(self):
self.manager.run_daemon_task(self._do_daemon)
await self.manager.wait_finished()
async def _do_daemon(self):
in_daemon.set()
while self.manager.is_running:
await trio.lowlevel.checkpoint()
async with background_trio_service(ServiceTest()) as manager:
await in_daemon.wait()
manager.cancel()
@pytest.mark.trio
async def test_trio_service_with_try_finally_cleanup():
ready_cancel = trio.Event()
class TryFinallyService(Service):
cleanup_up = False
async def run(self) -> None:
try:
ready_cancel.set()
await self.manager.wait_finished()
finally:
self.cleanup_up = True
service = TryFinallyService()
async with background_trio_service(service) as manager:
await ready_cancel.wait()
assert not service.cleanup_up
manager.cancel()
assert service.cleanup_up
@pytest.mark.trio
async def test_trio_service_with_try_finally_cleanup_with_unshielded_await():
ready_cancel = trio.Event()
class TryFinallyService(Service):
cleanup_up = False
async def run(self) -> None:
try:
ready_cancel.set()
await self.manager.wait_finished()
finally:
await trio.lowlevel.checkpoint()
self.cleanup_up = True
service = TryFinallyService()
async with background_trio_service(service) as manager:
await ready_cancel.wait()
assert not service.cleanup_up
manager.cancel()
assert not service.cleanup_up
@pytest.mark.trio
async def test_trio_service_with_try_finally_cleanup_with_shielded_await():
ready_cancel = trio.Event()
class TryFinallyService(Service):
cleanup_up = False
async def run(self) -> None:
try:
ready_cancel.set()
await self.manager.wait_finished()
finally:
with trio.CancelScope(shield=True):
await trio.lowlevel.checkpoint()
self.cleanup_up = True
service = TryFinallyService()
async with background_trio_service(service) as manager:
await ready_cancel.wait()
assert not service.cleanup_up
manager.cancel()
assert service.cleanup_up
@pytest.mark.trio
async def test_error_in_service_run():
class ServiceTest(Service):
async def run(self):
self.manager.run_daemon_task(self.manager.wait_finished)
raise ValueError("Exception inside run()")
with RaisesGroup(ValueError, strict=False):
await TrioManager.run_service(ServiceTest())
@pytest.mark.trio
async def test_daemon_task_finishes_leaving_children():
class ServiceTest(Service):
async def sleep_and_fail(self):
await trio.sleep(1)
raise AssertionError(
"This should not happen as the task should be cancelled"
)
async def buggy_daemon(self):
self.manager.run_task(self.sleep_and_fail)
async def run(self):
self.manager.run_daemon_task(self.buggy_daemon)
with RaisesGroup(DaemonTaskExit, strict=False):
await TrioManager.run_service(ServiceTest())

View File

@ -0,0 +1,109 @@
# Copied from https://github.com/ethereum/async-service
import pytest
import trio
from trio.testing import (
RaisesGroup,
)
from libp2p.tools.async_service import (
LifecycleError,
Service,
background_trio_service,
)
from libp2p.tools.async_service.trio_service import (
external_api,
)
class ExternalAPIService(Service):
async def run(self):
await self.manager.wait_finished()
@external_api
async def get_7(self, wait_return=None, signal_event=None):
if signal_event is not None:
signal_event.set()
if wait_return is not None:
await wait_return.wait()
return 7
@pytest.mark.trio
async def test_trio_service_external_api_fails_before_start():
service = ExternalAPIService()
# should raise if the service has not yet been started.
with pytest.raises(LifecycleError):
await service.get_7()
@pytest.mark.trio
async def test_trio_service_external_api_works_while_running():
service = ExternalAPIService()
async with background_trio_service(service):
assert await service.get_7() == 7
@pytest.mark.trio
async def test_trio_service_external_api_raises_when_cancelled():
service = ExternalAPIService()
async with background_trio_service(service) as manager:
with RaisesGroup(LifecycleError, strict=False):
async with trio.open_nursery() as nursery:
# an event to ensure that we are indeed within the body of the
is_within_fn = trio.Event()
trigger_return = trio.Event()
nursery.start_soon(service.get_7, trigger_return, is_within_fn)
# ensure we're within the body of the task.
await is_within_fn.wait()
# now cancel the service and trigger the return of the function.
manager.cancel()
# exiting the context block here will cause the background task
# to complete and shold raise the exception
# A direct call should also fail. This *should* be hitting the early
# return mechanism.
with pytest.raises(LifecycleError):
assert await service.get_7()
@pytest.mark.trio
async def test_trio_service_external_api_raises_when_finished():
service = ExternalAPIService()
async with background_trio_service(service) as manager:
pass
assert manager.is_finished
# A direct call should also fail. This *should* be hitting the early
# return mechanism.
with pytest.raises(LifecycleError):
assert await service.get_7()
@pytest.mark.trio
async def test_trio_external_api_call_that_schedules_task():
done = trio.Event()
class MyService(Service):
async def run(self):
await self.manager.wait_finished()
@external_api
async def do_scheduling(self):
self.manager.run_task(self.set_done)
async def set_done(self):
done.set()
service = MyService()
async with background_trio_service(service):
await service.do_scheduling()
with trio.fail_after(1):
await done.wait()

View File

@ -0,0 +1,86 @@
import pytest
import trio
from libp2p.tools.async_service import (
Service,
background_trio_service,
)
@pytest.mark.trio
async def test_trio_manager_stats():
ready = trio.Event()
class StatsTest(Service):
async def run(self):
# 2 that run forever
self.manager.run_task(trio.sleep_forever)
self.manager.run_task(trio.sleep_forever)
# 2 that complete
self.manager.run_task(trio.lowlevel.checkpoint)
self.manager.run_task(trio.lowlevel.checkpoint)
# 1 that spawns some children
self.manager.run_task(self.run_with_children, 4)
async def run_with_children(self, num_children):
for _ in range(num_children):
self.manager.run_task(trio.sleep_forever)
ready.set()
def run_external_root(self):
self.manager.run_task(trio.lowlevel.checkpoint)
service = StatsTest()
async with background_trio_service(service) as manager:
service.run_external_root()
assert len(manager._root_tasks) == 2
with trio.fail_after(1):
await ready.wait()
# we need to yield to the event loop a few times to allow the various
# tasks to schedule themselves and get running.
for _ in range(50):
await trio.lowlevel.checkpoint()
assert manager.stats.tasks.total_count == 10
assert manager.stats.tasks.finished_count == 3
assert manager.stats.tasks.pending_count == 7
# This is a simple test to ensure that finished tasks are removed from
# tracking to prevent unbounded memory growth.
assert len(manager._root_tasks) == 1
# now check after exiting
assert manager.stats.tasks.total_count == 10
assert manager.stats.tasks.finished_count == 10
assert manager.stats.tasks.pending_count == 0
@pytest.mark.trio
async def test_trio_manager_stats_does_not_count_main_run_method():
ready = trio.Event()
class StatsTest(Service):
async def run(self):
self.manager.run_task(trio.sleep_forever)
ready.set()
async with background_trio_service(StatsTest()) as manager:
with trio.fail_after(1):
await ready.wait()
# we need to yield to the event loop a few times to allow the various
# tasks to schedule themselves and get running.
for _ in range(10):
await trio.lowlevel.checkpoint()
assert manager.stats.tasks.total_count == 1
assert manager.stats.tasks.finished_count == 0
assert manager.stats.tasks.pending_count == 1
# now check after exiting
assert manager.stats.tasks.total_count == 1
assert manager.stats.tasks.finished_count == 1
assert manager.stats.tasks.pending_count == 0