"""Classes for managing a limited set of concurrent tasks."""
import asyncio
import logging
import time
from typing import Awaitable, Callable
# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(__name__)
async def delay_task(delay: float, task: Awaitable):
    """Wait a given amount of time before executing an awaitable.

    Args:
        delay: number of seconds to sleep before awaiting `task`
        task: the awaitable to run once the delay has elapsed

    Returns:
        The result of awaiting `task`.

    Raises:
        asyncio.CancelledError: if cancelled during the delay; in that case
            `task` is never awaited.

    """
    try:
        await asyncio.sleep(delay)
    except asyncio.CancelledError:
        # The wrapped awaitable will never be awaited now. If it is a
        # coroutine object, close it so it is finalized cleanly instead of
        # triggering a "coroutine was never awaited" warning.
        if asyncio.iscoroutine(task):
            task.close()
        raise
    return await task
class PendingTask:
    """Bookkeeping record for a single task and its eventual outcome."""

    def __init__(
        self,
        ident,
        fn: Callable[["PendingTask"], Awaitable],
        retries: int = None,
        retry_delay: float = None,
    ):
        """Set up the tracking state for a task that has not started yet.

        Args:
            ident: an identifier stored on the instance as `ident`
            fn: callable that receives this instance and returns an awaitable
            retries: retry limit stored on the instance as `retries`
            retry_delay: delay in seconds stored on the instance as
                `retry_delay`

        """
        loop = asyncio.get_event_loop()
        # Caller-supplied configuration
        self.ident = ident
        self.fn = fn
        self.retries = retries
        self.retry_delay = retry_delay
        # Mutable progress state
        self.attempts = 0
        self.running: asyncio.Future = None
        # Future exposing the final outcome; this is what `await` waits on
        self.future = loop.create_future()
        self.start = time.perf_counter()

    def done(self):
        """Report whether the outcome future has completed."""
        outcome = self.future
        return outcome.done()

    def exception(self):
        """Return the exception stored on the outcome future, if any."""
        outcome = self.future
        return outcome.exception()

    def result(self):
        """Return the result stored on the outcome future."""
        outcome = self.future
        return outcome.result()

    def cancel(self):
        """Cancel the outcome future and the running future if still pending."""
        outer, inner = self.future, self.running
        if not outer.done():
            outer.cancel()
        if inner and not inner.done():
            inner.cancel()

    def __await__(self):
        """Delegate awaiting to the outcome future."""
        outcome = self.future
        return outcome.__await__()
class TaskProcessor:
    """Class for managing a limited set of concurrent tasks.

    Tasks are tracked in `pending` from the moment they are submitted until
    their outcome future is done. `ready_event` is set while fewer than
    `max_pending` tasks are pending and `done_event` is set while none are.
    """

    def __init__(self, *, max_pending: int = 10):
        """Instantiate the dispatcher.

        Args:
            max_pending: number of pending tasks at which the processor stops
                reporting itself as ready

        """
        # Set whenever the pending set is empty
        self.done_event = asyncio.Event()
        self.done_event.set()
        self.max_pending = max_pending
        self.pending = set()
        self.pending_lock = asyncio.Lock()
        # Set whenever fewer than max_pending tasks are pending
        self.ready_event = asyncio.Event()
        self.ready_event.set()

    def ready(self):
        """Check if the processor is ready to accept more tasks."""
        return self.ready_event.is_set()

    async def wait_ready(self):
        """Wait for the processor to be ready for more tasks."""
        await self.ready_event.wait()

    def done(self):
        """Check if the processor has no pending tasks left."""
        return self.done_event.is_set()

    async def wait_done(self):
        """Wait for all pending tasks to complete."""
        await self.done_event.wait()

    def _enqueue_task(self, task: PendingTask):
        """Start the next attempt of the given pending task.

        Args:
            task: the pending task to start (or restart after a failure)

        """
        if task.done():
            # The task was completed (e.g. cancelled) before this attempt
            # could start. No running future will ever report back, so
            # release its slot in the pending set now.
            asyncio.ensure_future(self._check_task(task))
            return

        awaitable = task.fn(task)
        if awaitable:
            # Only attempts after the first one are delayed
            if task.attempts and task.retry_delay:
                awaitable = delay_task(task.retry_delay, awaitable)
            task.attempts += 1
            task.running = asyncio.ensure_future(awaitable)
            task.running.add_done_callback(
                lambda fut: asyncio.ensure_future(self._check_task(task))
            )
        else:
            # Nothing to run: complete immediately with no result
            task.future.set_result(None)
            asyncio.ensure_future(self._check_task(task))

    async def _check_task(self, task: PendingTask):
        """Record the outcome of a finished attempt and update bookkeeping.

        Args:
            task: the pending task whose running future may have completed

        """
        if task.running and task.running.done():
            future = task.running
            task.running = None
            if future.cancelled():
                # exception() and result() raise CancelledError on a
                # cancelled future, which would abort this coroutine before
                # the task is removed from the pending set.
                if not task.done():
                    task.future.cancel()
            else:
                exception = future.exception()
                if exception:
                    LOGGER.debug(
                        "Task raised exception: (%s) %s",
                        task.ident or task,
                        exception,
                    )
                    if task.done():
                        # Outcome already settled (e.g. cancelled by the
                        # caller); there is nothing left to retry or report.
                        pass
                    elif task.retries and task.attempts < task.retries:
                        asyncio.get_event_loop().call_soon(
                            self._enqueue_task, task
                        )
                    else:
                        LOGGER.warning("Task failed: %s", task.ident or task)
                        task.future.set_exception(exception)
                elif not task.done():
                    task.future.set_result(future.result())

        if task.done():
            async with self.pending_lock:
                if task in self.pending:
                    self.pending.remove(task)
                else:
                    LOGGER.warning(
                        "Task not found in pending list: %s", task.ident or task
                    )
                if len(self.pending) < self.max_pending:
                    self.ready_event.set()
                if not self.pending:
                    self.done_event.set()

    async def run_retry(
        self,
        fn: Callable[[PendingTask], Awaitable],
        *,
        ident=None,
        retries: int = 5,
        retry_delay: float = 10.0,
        when_ready: bool = True,
    ) -> PendingTask:
        """Process a task and track the result.

        Args:
            fn: callable invoked with the pending task for every attempt; it
                returns the awaitable to run
            ident: an identifier for the task, used in log messages
            retries: a failed attempt is re-run while the number of attempts
                made so far is below this value
            retry_delay: seconds to wait before each attempt after the first
            when_ready: wait until the processor is ready before adding the
                task

        Returns:
            The pending task, which can be awaited for the final result.

        """
        if when_ready:
            await self.wait_ready()
        task = PendingTask(ident, fn, retries=retries, retry_delay=retry_delay)
        async with self.pending_lock:
            self.pending.add(task)
            self.done_event.clear()
            if len(self.pending) >= self.max_pending:
                self.ready_event.clear()
        # Start the first attempt on the next loop iteration
        asyncio.get_event_loop().call_soon(self._enqueue_task, task)
        return task

    async def run_task(
        self, task: Awaitable, *, ident=None, when_ready: bool = True
    ) -> PendingTask:
        """Run a single coroutine with no retries.

        Args:
            task: the awaitable to run
            ident: an identifier for the task, used in log messages
            when_ready: wait until the processor is ready before adding the
                task

        Returns:
            The pending task, which can be awaited for the final result.

        """
        return await self.run_retry(
            lambda pending: task, ident=ident, retries=0, when_ready=when_ready
        )