"""Classes for managing a set of asyncio tasks."""
import asyncio
import logging
import time
from typing import Callable, Coroutine, Optional, Tuple
LOGGER = logging.getLogger(__name__)
[docs]
def coro_ident(coro: Coroutine):
"""Extract an identifier for a coroutine."""
return coro and (hasattr(coro, "__qualname__") and coro.__qualname__ or repr(coro))
[docs]
async def coro_timed(coro: Coroutine, timing: dict):
"""Capture timing for a coroutine."""
timing["started"] = time.perf_counter()
try:
return await coro
finally:
timing["ended"] = time.perf_counter()
[docs]
def task_exc_info(task: asyncio.Task):
"""Extract exception info from an asyncio task."""
if not task or not task.done():
return
try:
exc_val = task.exception()
except asyncio.CancelledError:
exc_val = asyncio.CancelledError("Task was cancelled")
if exc_val:
return type(exc_val), exc_val, exc_val.__traceback__
[docs]
class CompletedTask:
"""Represent the result of a queued task."""
def __init__(
self,
task: asyncio.Task,
exc_info: Tuple,
ident: Optional[str] = None,
timing: Optional[dict] = None,
):
"""Initialize the completed task."""
self.exc_info = exc_info
self.ident = ident
self.task = task
self.timing = timing
def __repr__(self) -> str:
"""Generate string representation for logging."""
return f"<{self.__class__.__name__} ident={self.ident} timing={self.timing}>"
[docs]
class PendingTask:
"""Represent a task in the queue."""
def __init__(
self,
coro: Coroutine,
complete_hook: Optional[Callable] = None,
ident: Optional[str] = None,
task_future: asyncio.Future = None,
queued_time: Optional[float] = None,
):
"""Initialize the pending task.
Args:
coro: The coroutine to be run
complete_hook: A callback to run on completion
ident: A string identifier for the task
task_future: A future to be resolved to the asyncio Task
queued_time: When the pending task was added to the queue
"""
if not asyncio.iscoroutine(coro):
raise ValueError(f"Expected coroutine, got {coro}")
self._cancelled = False
self.complete_hook = complete_hook
self.coro = coro
self.queued_time: float = queued_time
self.unqueued_time: Optional[float] = None
self.ident = ident or coro_ident(coro)
self.task_future = task_future or asyncio.get_event_loop().create_future()
[docs]
def cancel(self):
"""Cancel the pending task."""
self.coro.close()
if not self.task_future.done():
self.task_future.cancel()
self._cancelled = True
@property
def cancelled(self):
"""Accessor for the cancelled property."""
return self._cancelled
@property
def task(self) -> asyncio.Task:
"""Accessor for the task."""
return self.task_future.done() and self.task_future.result()
@task.setter
def task(self, task: asyncio.Task):
"""Setter for the task."""
if self.task_future.cancelled():
return
elif self.task_future.done():
raise ValueError("Cannot set pending task future, already done")
self.task_future.set_result(task)
def __await__(self):
"""Wait for the task to be queued."""
return self.task_future.__await__()
def __repr__(self) -> str:
"""Generate string representation for logging."""
return f"<{self.__class__.__name__} ident={self.ident}>"
[docs]
class TaskQueue:
"""A class for managing a set of asyncio tasks."""
def __init__(
self,
max_active: int = 0,
timed: bool = False,
trace_fn: Optional[Callable] = None,
):
"""Initialize the task queue.
Args:
max_active: The maximum number of tasks to automatically run
timed: A flag indicating that timing should be collected for tasks
trace_fn: A callback for all completed tasks
"""
self.loop = None # Lazy initialization
self.active_tasks: list[asyncio.Task] = []
self.pending_tasks: list[PendingTask] = []
self.timed = timed
self.total_done = 0
self.total_failed = 0
self.total_started = 0
self._trace_fn = trace_fn
self._cancelled = False
self._drain_evt = None # Lazy initialization
self._drain_task: asyncio.Task = None
self._max_active = max_active
def _ensure_loop(self):
"""Ensure the event loop is initialized."""
if self.loop is None:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
# No running loop, try to get the event loop policy loop
try:
self.loop = asyncio.get_event_loop()
except RuntimeError:
# Create a new event loop if none exists
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
if self._drain_evt is None:
self._drain_evt = asyncio.Event()
@property
def cancelled(self) -> bool:
"""Accessor for the cancelled property of the queue."""
return self._cancelled
@property
def max_active(self) -> int:
"""Accessor for the maximum number of active tasks in the queue."""
return self._max_active
@property
def ready(self) -> bool:
"""Accessor for the ready property of the queue."""
return (
not self._cancelled
and not self._max_active
or self.current_size < self._max_active
)
@property
def current_active(self) -> int:
"""Accessor for the current number of active tasks in the queue."""
return len(self.active_tasks)
@property
def current_pending(self) -> int:
"""Accessor for the current number of pending tasks in the queue."""
return len(self.pending_tasks)
@property
def current_size(self) -> int:
"""Accessor for the total number of tasks in the queue."""
return len(self.active_tasks) + len(self.pending_tasks)
def __bool__(self) -> bool:
"""Support for the bool() builtin.
Return:
True - the task queue exists even if there are no tasks
"""
return True
def __len__(self) -> int:
"""Support for the len() builtin."""
return self.current_size
[docs]
def drain(self) -> asyncio.Task:
"""Start the process to run queued tasks."""
self._ensure_loop() # Ensure loop is initialized
if self._drain_task and not self._drain_task.done():
self._drain_evt.set()
elif self.pending_tasks:
self._drain_task = self.loop.create_task(self._drain_loop())
self._drain_task.add_done_callback(lambda task: self._drain_done(task))
return self._drain_task
def _drain_done(self, task: asyncio.Task):
"""Handle completion of the drain process."""
exc_info = task_exc_info(task)
if exc_info:
LOGGER.exception("Error draining task queue:", exc_info=exc_info)
if self._drain_task and self._drain_task.done():
self._drain_task = None
async def _drain_loop(self):
"""Run pending tasks while there is room in the queue."""
# Note: this method should not call async methods apart from
# waiting for the drain event, to avoid yielding to other queue methods
while True:
self._drain_evt.clear()
while self.pending_tasks and (
not self._max_active or len(self.active_tasks) < self._max_active
):
pending: PendingTask = self.pending_tasks.pop(0)
if pending.queued_time:
pending.unqueued_time = time.perf_counter()
timing = {
"queued": pending.queued_time,
"unqueued": pending.unqueued_time,
}
else:
timing = None
task = self.run(
pending.coro, pending.complete_hook, pending.ident, timing
)
try:
pending.task = task
except ValueError:
LOGGER.warning("Pending task future already fulfilled")
if self.pending_tasks:
await self._drain_evt.wait()
else:
break
[docs]
def add_pending(self, pending: PendingTask):
"""Add a task to the pending queue.
Args:
pending: The `PendingTask` to add to the task queue
"""
if self.timed and not pending.queued_time:
pending.queued_time = time.perf_counter()
self.pending_tasks.append(pending)
self.drain()
[docs]
def add_active(
self,
task: asyncio.Task,
task_complete: Optional[Callable] = None,
ident: Optional[str] = None,
timing: Optional[dict] = None,
) -> asyncio.Task:
"""Register an active async task with an optional completion callback.
Args:
task: The asyncio task instance
task_complete: An optional callback to run on completion
ident: A string identifier for the task
timing: An optional dictionary of timing information
"""
self.active_tasks.append(task)
task.add_done_callback(
lambda fut: self.completed_task(task, task_complete, ident, timing)
)
self.total_started += 1
return task
[docs]
def run(
self,
coro: Coroutine,
task_complete: Optional[Callable] = None,
ident: Optional[str] = None,
timing: Optional[dict] = None,
) -> asyncio.Task:
"""Start executing a coroutine as an async task, bypassing the pending queue.
Args:
coro: The coroutine to run
task_complete: An optional callback to run on completion
ident: A string identifier for the task
timing: An optional dictionary of timing information
Returns: the new asyncio task instance
"""
if self._cancelled:
raise RuntimeError("Task queue has been cancelled")
if not asyncio.iscoroutine(coro):
raise ValueError(f"Expected coroutine, got {coro}")
if not ident:
ident = coro_ident(coro)
if self.timed:
if not timing:
timing = {}
coro = coro_timed(coro, timing)
self._ensure_loop() # Ensure loop is initialized
task = self.loop.create_task(coro)
return self.add_active(task, task_complete, ident, timing)
[docs]
def put(
self,
coro: Coroutine,
task_complete: Optional[Callable] = None,
ident: Optional[str] = None,
) -> PendingTask:
"""Add a new task to the queue, delaying execution if busy.
Args:
coro: The coroutine to run
task_complete: A callback to run on completion
ident: A string identifier for the task
Returns: a future resolving to the asyncio task instance once queued
"""
pending = PendingTask(coro, task_complete, ident)
if self._cancelled:
pending.cancel()
elif self.ready:
pending.task = self.run(coro, task_complete, pending.ident)
else:
self.add_pending(pending)
return pending
[docs]
def completed_task(
self,
task: asyncio.Task,
task_complete: Callable,
ident: str,
timing: Optional[dict] = None,
):
"""Clean up after a task has completed and run callbacks."""
exc_info = task_exc_info(task)
if exc_info:
self.total_failed += 1
if not task_complete and not self._trace_fn:
LOGGER.exception("Error running task %s", ident or "", exc_info=exc_info)
else:
self.total_done += 1
if task_complete or self._trace_fn:
completed = CompletedTask(task, exc_info, ident, timing)
try:
if task_complete:
task_complete(completed)
if self._trace_fn:
self._trace_fn(completed)
except Exception:
LOGGER.exception("Error finalizing task %s", completed)
try:
self.active_tasks.remove(task)
except ValueError:
pass
self.drain()
[docs]
def cancel_pending(self):
"""Cancel any pending tasks in the queue."""
if self._drain_task:
self._drain_task.cancel()
self._drain_task = None
for pending in self.pending_tasks:
pending.cancel()
self.pending_tasks = []
[docs]
def cancel(self):
"""Cancel any pending or active tasks in the queue."""
self._cancelled = True
self.cancel_pending()
for task in self.active_tasks:
if not task.done():
task.cancel()
[docs]
async def complete(self, timeout: Optional[float] = None, cleanup: bool = True):
"""Cancel any pending tasks and wait for, or cancel active tasks."""
self._cancelled = True
self.cancel_pending()
if timeout or timeout is None:
try:
await self.wait_for(timeout)
except asyncio.TimeoutError:
pass
for task in self.active_tasks:
if not task.done():
task.cancel()
if cleanup:
while True:
drain = self.drain()
if not drain:
break
await drain
[docs]
async def flush(self):
"""Wait for any active or pending tasks to be completed."""
self.drain()
while self.active_tasks or self._drain_task:
if self._drain_task:
await self._drain_task
if self.active_tasks:
await asyncio.wait(self.active_tasks)
def __await__(self):
"""Handle the builtin await operator."""
yield from self.flush().__await__()
[docs]
async def wait_for(self, timeout: float):
"""Wait for all queued tasks to complete with a timeout."""
return await asyncio.wait_for(self.flush(), timeout)
[docs]
async def wait_for_completion(self):
"""Wait for all active tasks to complete with timeout.
This is safer than flush() for testing as it doesn't try to
manage the drain loop, just waits for existing tasks.
"""
if not self.active_tasks:
return
try:
await asyncio.wait_for(
asyncio.gather(*self.active_tasks, return_exceptions=True),
timeout=5.0,
)
except asyncio.TimeoutError:
# Tasks didn't complete in time, but that's okay for testing
pass