Source code for acapy_agent.utils.task_queue

"""Classes for managing a set of asyncio tasks."""

import asyncio
import logging
import time
from typing import Callable, Coroutine, Optional, Tuple

# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)


def coro_ident(coro: Coroutine):
    """Derive a short identifying label for a coroutine.

    Args:
        coro: The coroutine (or any other object) to describe

    Returns:
        The input itself when it is falsy; otherwise its ``__qualname__``
        when that attribute exists and is non-empty, else ``repr(coro)``

    """
    if not coro:
        # Falsy inputs (e.g. None) are passed straight back to the caller.
        return coro
    qualname = getattr(coro, "__qualname__", None)
    if qualname:
        return qualname
    return repr(coro)
async def coro_timed(coro: Coroutine, timing: dict):
    """Await a coroutine while recording when it started and ended.

    Args:
        coro: The coroutine to await
        timing: A dictionary updated in place with ``"started"`` and
            ``"ended"`` values taken from ``time.perf_counter()``

    Returns:
        Whatever the wrapped coroutine returns

    """
    timing["started"] = time.perf_counter()
    try:
        outcome = await coro
    finally:
        # Recorded on success, failure and cancellation alike.
        timing["ended"] = time.perf_counter()
    return outcome
def task_exc_info(task: asyncio.Task):
    """Build an ``exc_info``-style triple for a finished asyncio task.

    Args:
        task: The task (or future) to inspect; may be ``None``

    Returns:
        ``None`` when the task is missing, not yet done, or finished without
        an exception; otherwise ``(type, value, traceback)``. A cancelled
        task is reported with a fresh ``CancelledError("Task was cancelled")``.

    """
    finished = bool(task) and task.done()
    if not finished:
        return None
    try:
        error = task.exception()
    except asyncio.CancelledError:
        # exception() raises rather than returns when the task was cancelled.
        error = asyncio.CancelledError("Task was cancelled")
    if not error:
        return None
    return type(error), error, error.__traceback__
class CompletedTask:
    """Outcome record for a task that has finished running."""

    def __init__(
        self,
        task: asyncio.Task,
        exc_info: Tuple,
        ident: Optional[str] = None,
        timing: Optional[dict] = None,
    ):
        """Store the finished task together with its outcome details.

        Args:
            task: The asyncio task that finished
            exc_info: Exception info for the task, if any
            ident: A string identifier for the task
            timing: An optional dictionary of timing information

        """
        self.task = task
        self.ident = ident
        self.timing = timing
        self.exc_info = exc_info

    def __repr__(self) -> str:
        """Render a compact description suitable for log output."""
        cls_name = type(self).__name__
        return f"<{cls_name} ident={self.ident} timing={self.timing}>"
class PendingTask:
    """A coroutine waiting in the queue for its turn to run."""

    def __init__(
        self,
        coro: Coroutine,
        complete_hook: Optional[Callable] = None,
        ident: Optional[str] = None,
        task_future: asyncio.Future = None,
        queued_time: Optional[float] = None,
    ):
        """Set up the pending entry.

        Args:
            coro: The coroutine to be run
            complete_hook: A callback to run on completion
            ident: A string identifier for the task; derived from the
                coroutine when not given
            task_future: A future to be resolved to the asyncio Task; one is
                created on the current event loop when not given
            queued_time: When the pending task was added to the queue

        Raises:
            ValueError: If ``coro`` is not a coroutine object

        """
        if not asyncio.iscoroutine(coro):
            raise ValueError(f"Expected coroutine, got {coro}")
        self.coro = coro
        self.complete_hook = complete_hook
        self.ident = ident if ident else coro_ident(coro)
        self.queued_time: float = queued_time
        self.unqueued_time: Optional[float] = None
        if not task_future:
            task_future = asyncio.get_event_loop().create_future()
        self.task_future = task_future
        self._cancelled = False

    def cancel(self):
        """Close the coroutine and cancel the task future if still open."""
        self.coro.close()
        future = self.task_future
        if not future.done():
            future.cancel()
        self._cancelled = True

    @property
    def cancelled(self):
        """Report whether `cancel` has been called on this entry."""
        return self._cancelled

    @property
    def task(self) -> asyncio.Task:
        """Return the resolved task, or ``False`` while still unresolved.

        The future's ``result()`` is returned as-is, so a cancelled future
        propagates its ``CancelledError`` to the caller.
        """
        future = self.task_future
        if not future.done():
            return False
        return future.result()

    @task.setter
    def task(self, task: asyncio.Task):
        """Resolve the task future with the started task.

        A cancelled future is left untouched; a future that is already
        resolved raises ``ValueError``.
        """
        future = self.task_future
        if future.cancelled():
            return
        if future.done():
            raise ValueError("Cannot set pending task future, already done")
        future.set_result(task)

    def __await__(self):
        """Wait until the task future has been resolved."""
        return self.task_future.__await__()

    def __repr__(self) -> str:
        """Render a compact description suitable for log output."""
        cls_name = type(self).__name__
        return f"<{cls_name} ident={self.ident}>"
[docs] class TaskQueue: """A class for managing a set of asyncio tasks.""" def __init__( self, max_active: int = 0, timed: bool = False, trace_fn: Optional[Callable] = None, ): """Initialize the task queue. Args: max_active: The maximum number of tasks to automatically run timed: A flag indicating that timing should be collected for tasks trace_fn: A callback for all completed tasks """ self.loop = None # Lazy initialization self.active_tasks: list[asyncio.Task] = [] self.pending_tasks: list[PendingTask] = [] self.timed = timed self.total_done = 0 self.total_failed = 0 self.total_started = 0 self._trace_fn = trace_fn self._cancelled = False self._drain_evt = None # Lazy initialization self._drain_task: asyncio.Task = None self._max_active = max_active def _ensure_loop(self): """Ensure the event loop is initialized.""" if self.loop is None: try: self.loop = asyncio.get_running_loop() except RuntimeError: # No running loop, try to get the event loop policy loop try: self.loop = asyncio.get_event_loop() except RuntimeError: # Create a new event loop if none exists self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) if self._drain_evt is None: self._drain_evt = asyncio.Event() @property def cancelled(self) -> bool: """Accessor for the cancelled property of the queue.""" return self._cancelled @property def max_active(self) -> int: """Accessor for the maximum number of active tasks in the queue.""" return self._max_active @property def ready(self) -> bool: """Accessor for the ready property of the queue.""" return ( not self._cancelled and not self._max_active or self.current_size < self._max_active ) @property def current_active(self) -> int: """Accessor for the current number of active tasks in the queue.""" return len(self.active_tasks) @property def current_pending(self) -> int: """Accessor for the current number of pending tasks in the queue.""" return len(self.pending_tasks) @property def current_size(self) -> int: """Accessor for the total 
number of tasks in the queue.""" return len(self.active_tasks) + len(self.pending_tasks) def __bool__(self) -> bool: """Support for the bool() builtin. Return: True - the task queue exists even if there are no tasks """ return True def __len__(self) -> int: """Support for the len() builtin.""" return self.current_size
[docs] def drain(self) -> asyncio.Task: """Start the process to run queued tasks.""" self._ensure_loop() # Ensure loop is initialized if self._drain_task and not self._drain_task.done(): self._drain_evt.set() elif self.pending_tasks: self._drain_task = self.loop.create_task(self._drain_loop()) self._drain_task.add_done_callback(lambda task: self._drain_done(task)) return self._drain_task
def _drain_done(self, task: asyncio.Task): """Handle completion of the drain process.""" exc_info = task_exc_info(task) if exc_info: LOGGER.exception("Error draining task queue:", exc_info=exc_info) if self._drain_task and self._drain_task.done(): self._drain_task = None async def _drain_loop(self): """Run pending tasks while there is room in the queue.""" # Note: this method should not call async methods apart from # waiting for the drain event, to avoid yielding to other queue methods while True: self._drain_evt.clear() while self.pending_tasks and ( not self._max_active or len(self.active_tasks) < self._max_active ): pending: PendingTask = self.pending_tasks.pop(0) if pending.queued_time: pending.unqueued_time = time.perf_counter() timing = { "queued": pending.queued_time, "unqueued": pending.unqueued_time, } else: timing = None task = self.run( pending.coro, pending.complete_hook, pending.ident, timing ) try: pending.task = task except ValueError: LOGGER.warning("Pending task future already fulfilled") if self.pending_tasks: await self._drain_evt.wait() else: break
[docs] def add_pending(self, pending: PendingTask): """Add a task to the pending queue. Args: pending: The `PendingTask` to add to the task queue """ if self.timed and not pending.queued_time: pending.queued_time = time.perf_counter() self.pending_tasks.append(pending) self.drain()
[docs] def add_active( self, task: asyncio.Task, task_complete: Optional[Callable] = None, ident: Optional[str] = None, timing: Optional[dict] = None, ) -> asyncio.Task: """Register an active async task with an optional completion callback. Args: task: The asyncio task instance task_complete: An optional callback to run on completion ident: A string identifier for the task timing: An optional dictionary of timing information """ self.active_tasks.append(task) task.add_done_callback( lambda fut: self.completed_task(task, task_complete, ident, timing) ) self.total_started += 1 return task
[docs] def run( self, coro: Coroutine, task_complete: Optional[Callable] = None, ident: Optional[str] = None, timing: Optional[dict] = None, ) -> asyncio.Task: """Start executing a coroutine as an async task, bypassing the pending queue. Args: coro: The coroutine to run task_complete: An optional callback to run on completion ident: A string identifier for the task timing: An optional dictionary of timing information Returns: the new asyncio task instance """ if self._cancelled: raise RuntimeError("Task queue has been cancelled") if not asyncio.iscoroutine(coro): raise ValueError(f"Expected coroutine, got {coro}") if not ident: ident = coro_ident(coro) if self.timed: if not timing: timing = {} coro = coro_timed(coro, timing) self._ensure_loop() # Ensure loop is initialized task = self.loop.create_task(coro) return self.add_active(task, task_complete, ident, timing)
[docs] def put( self, coro: Coroutine, task_complete: Optional[Callable] = None, ident: Optional[str] = None, ) -> PendingTask: """Add a new task to the queue, delaying execution if busy. Args: coro: The coroutine to run task_complete: A callback to run on completion ident: A string identifier for the task Returns: a future resolving to the asyncio task instance once queued """ pending = PendingTask(coro, task_complete, ident) if self._cancelled: pending.cancel() elif self.ready: pending.task = self.run(coro, task_complete, pending.ident) else: self.add_pending(pending) return pending
[docs] def completed_task( self, task: asyncio.Task, task_complete: Callable, ident: str, timing: Optional[dict] = None, ): """Clean up after a task has completed and run callbacks.""" exc_info = task_exc_info(task) if exc_info: self.total_failed += 1 if not task_complete and not self._trace_fn: LOGGER.exception("Error running task %s", ident or "", exc_info=exc_info) else: self.total_done += 1 if task_complete or self._trace_fn: completed = CompletedTask(task, exc_info, ident, timing) try: if task_complete: task_complete(completed) if self._trace_fn: self._trace_fn(completed) except Exception: LOGGER.exception("Error finalizing task %s", completed) try: self.active_tasks.remove(task) except ValueError: pass self.drain()
[docs] def cancel_pending(self): """Cancel any pending tasks in the queue.""" if self._drain_task: self._drain_task.cancel() self._drain_task = None for pending in self.pending_tasks: pending.cancel() self.pending_tasks = []
[docs] def cancel(self): """Cancel any pending or active tasks in the queue.""" self._cancelled = True self.cancel_pending() for task in self.active_tasks: if not task.done(): task.cancel()
[docs] async def complete(self, timeout: Optional[float] = None, cleanup: bool = True): """Cancel any pending tasks and wait for, or cancel active tasks.""" self._cancelled = True self.cancel_pending() if timeout or timeout is None: try: await self.wait_for(timeout) except asyncio.TimeoutError: pass for task in self.active_tasks: if not task.done(): task.cancel() if cleanup: while True: drain = self.drain() if not drain: break await drain
[docs] async def flush(self): """Wait for any active or pending tasks to be completed.""" self.drain() while self.active_tasks or self._drain_task: if self._drain_task: await self._drain_task if self.active_tasks: await asyncio.wait(self.active_tasks)
def __await__(self): """Handle the builtin await operator.""" yield from self.flush().__await__()
[docs] async def wait_for(self, timeout: float): """Wait for all queued tasks to complete with a timeout.""" return await asyncio.wait_for(self.flush(), timeout)
[docs] async def wait_for_completion(self): """Wait for all active tasks to complete with timeout. This is safer than flush() for testing as it doesn't try to manage the drain loop, just waits for existing tasks. """ if not self.active_tasks: return try: await asyncio.wait_for( asyncio.gather(*self.active_tasks, return_exceptions=True), timeout=5.0, ) except asyncio.TimeoutError: # Tasks didn't complete in time, but that's okay for testing pass