"""Outbound transport manager."""
import asyncio
import json
import logging
import time
from typing import Callable, Type, Union
from urllib.parse import urlparse
from ...connections.models.connection_target import ConnectionTarget
from ...config.injection_context import InjectionContext
from ...utils.classloader import ClassLoader, ModuleLoadError, ClassNotFoundError
from ...utils.stats import Collector
from ...utils.task_queue import CompletedTask, TaskQueue, task_exc_info
from ...utils.tracing import trace_event, get_timer
from ..wire_format import BaseWireFormat
from .base import (
BaseOutboundTransport,
OutboundDeliveryError,
OutboundTransportRegistrationError,
)
from .message import OutboundMessage
LOGGER = logging.getLogger(__name__)
MODULE_BASE_PATH = "aries_cloudagent.transport.outbound"
[docs]class QueuedOutboundMessage:
"""Class representing an outbound message pending delivery."""
STATE_NEW = "new"
STATE_PENDING = "pending"
STATE_ENCODE = "encode"
STATE_DELIVER = "deliver"
STATE_RETRY = "retry"
STATE_DONE = "done"
def __init__(
self,
context: InjectionContext,
message: OutboundMessage,
target: ConnectionTarget,
transport_id: str,
):
"""Initialize the queued outbound message."""
self.context = context
self.endpoint = target and target.endpoint
self.error: Exception = None
self.message = message
self.payload: Union[str, bytes] = None
self.retries = None
self.retry_at: float = None
self.state = self.STATE_NEW
self.target = target
self.task: asyncio.Task = None
self.transport_id: str = transport_id
[docs]class OutboundTransportManager:
"""Outbound transport manager class."""
def __init__(
self, context: InjectionContext, handle_not_delivered: Callable = None
):
"""
Initialize a `OutboundTransportManager` instance.
Args:
context: The application context
handle_not_delivered: An optional handler for undelivered messages
"""
self.context = context
self.loop = asyncio.get_event_loop()
self.handle_not_delivered = handle_not_delivered
self.outbound_buffer = []
self.outbound_event = asyncio.Event()
self.outbound_new = []
self.registered_schemes = {}
self.registered_transports = {}
self.running_transports = {}
self.task_queue = TaskQueue(max_active=200)
self._process_task: asyncio.Task = None
[docs] async def setup(self):
"""Perform setup operations."""
outbound_transports = (
self.context.settings.get("transport.outbound_configs") or []
)
for outbound_transport in outbound_transports:
self.register(outbound_transport)
[docs] def register(self, module: str) -> str:
"""
Register a new outbound transport by module path.
Args:
module: Module name to register
Raises:
OutboundTransportRegistrationError: If the imported class cannot
be located
OutboundTransportRegistrationError: If the imported class does not
specify a schemes attribute
OutboundTransportRegistrationError: If the scheme has already been
registered
"""
try:
imported_class = ClassLoader.load_subclass_of(
BaseOutboundTransport, module, MODULE_BASE_PATH
)
except (ModuleLoadError, ClassNotFoundError):
raise OutboundTransportRegistrationError(
f"Outbound transport module {module} could not be resolved."
)
return self.register_class(imported_class)
[docs] def register_class(
self, transport_class: Type[BaseOutboundTransport], transport_id: str = None
) -> str:
"""
Register a new outbound transport class.
Args:
transport_class: Transport class to register
Raises:
OutboundTransportRegistrationError: If the imported class does not
specify a schemes attribute
OutboundTransportRegistrationError: If the scheme has already been
registered
"""
try:
schemes = transport_class.schemes
except AttributeError:
raise OutboundTransportRegistrationError(
f"Imported class {transport_class} does not "
+ "specify a required 'schemes' attribute"
)
if not transport_id:
transport_id = transport_class.__qualname__
for scheme in schemes:
if scheme in self.registered_schemes:
# A scheme can only be registered once
raise OutboundTransportRegistrationError(
f"Cannot register transport '{transport_id}'"
f"for '{scheme}' scheme because the scheme"
"has already been registered"
)
self.registered_transports[transport_id] = transport_class
for scheme in schemes:
self.registered_schemes[scheme] = transport_id
return transport_id
[docs] async def start_transport(self, transport_id: str):
"""Start a registered transport."""
transport = self.registered_transports[transport_id]()
transport.collector = await self.context.inject(Collector, required=False)
await transport.start()
self.running_transports[transport_id] = transport
[docs] async def start(self):
"""Start all transports and feed messages from the queue."""
for transport_id in self.registered_transports:
self.task_queue.run(self.start_transport(transport_id))
[docs] async def stop(self, wait: bool = True):
"""Stop all running transports."""
if self._process_task and not self._process_task.done():
self._process_task.cancel()
await self.task_queue.complete(None if wait else 0)
for transport in self.running_transports.values():
await transport.stop()
self.running_transports = {}
[docs] def get_registered_transport_for_scheme(self, scheme: str) -> str:
"""Find the registered transport ID for a given scheme."""
try:
return next(
transport_id
for transport_id, transport in self.registered_transports.items()
if scheme in transport.schemes
)
except StopIteration:
pass
[docs] def get_running_transport_for_scheme(self, scheme: str) -> str:
"""Find the running transport ID for a given scheme."""
try:
return next(
transport_id
for transport_id, transport in self.running_transports.items()
if scheme in transport.schemes
)
except StopIteration:
pass
[docs] def get_running_transport_for_endpoint(self, endpoint: str):
"""Find the running transport ID to use for a given endpoint."""
# Grab the scheme from the uri
scheme = urlparse(endpoint).scheme
if scheme == "":
raise OutboundDeliveryError(
f"The uri '{endpoint}' does not specify a scheme"
)
# Look up transport that is registered to handle this scheme
transport_id = self.get_running_transport_for_scheme(scheme)
if not transport_id:
raise OutboundDeliveryError(
f"No transport driver exists to handle scheme '{scheme}'"
)
return transport_id
[docs] def get_transport_instance(self, transport_id: str) -> BaseOutboundTransport:
"""Get an instance of a running transport by ID."""
return self.running_transports[transport_id]
[docs] def enqueue_message(self, context: InjectionContext, outbound: OutboundMessage):
"""
Add an outbound message to the queue.
Args:
context: The context of the request
outbound: The outbound message to deliver
"""
targets = [outbound.target] if outbound.target else (outbound.target_list or [])
transport_id = None
for target in targets:
endpoint = target.endpoint
try:
transport_id = self.get_running_transport_for_endpoint(endpoint)
except OutboundDeliveryError:
pass
if transport_id:
break
if not transport_id:
raise OutboundDeliveryError("No supported transport for outbound message")
queued = QueuedOutboundMessage(context, outbound, target, transport_id)
queued.retries = 4
self.outbound_new.append(queued)
self.process_queued()
[docs] def enqueue_webhook(
self, topic: str, payload: dict, endpoint: str, max_attempts: int = None
):
"""
Add a webhook to the queue.
Args:
topic: The webhook topic
payload: The webhook payload
endpoint: The webhook endpoint
max_attempts: Override the maximum number of attempts
Raises:
OutboundDeliveryError: if the associated transport is not running
"""
transport_id = self.get_running_transport_for_endpoint(endpoint)
queued = QueuedOutboundMessage(None, None, None, transport_id)
queued.endpoint = f"{endpoint}/topic/{topic}/"
queued.payload = json.dumps(payload)
queued.state = QueuedOutboundMessage.STATE_PENDING
queued.retries = 4 if max_attempts is None else max_attempts - 1
self.outbound_new.append(queued)
self.process_queued()
[docs] def process_queued(self) -> asyncio.Task:
"""
Start the process to deliver queued messages if necessary.
Returns: the current queue processing task or None
"""
if self._process_task and not self._process_task.done():
self.outbound_event.set()
elif self.outbound_new or self.outbound_buffer:
self._process_task = self.loop.create_task(self._process_loop())
self._process_task.add_done_callback(lambda task: self._process_done(task))
return self._process_task
def _process_done(self, task: asyncio.Task):
"""Handle completion of the drain process."""
exc_info = task_exc_info(task)
if exc_info:
LOGGER.exception(
"Exception in outbound queue processing:", exc_info=exc_info
)
if self._process_task and self._process_task.done():
self._process_task = None
async def _process_loop(self):
"""Continually kick off encoding and delivery on outbound messages."""
# Note: this method should not call async methods apart from
# waiting for the updated event, to avoid yielding to other queue methods
while True:
self.outbound_event.clear()
loop_time = get_timer()
upd_buffer = []
for queued in self.outbound_buffer:
if queued.state == QueuedOutboundMessage.STATE_DONE:
if queued.error:
LOGGER.exception(
"Outbound message could not be delivered to %s",
queued.endpoint,
exc_info=queued.error,
)
if self.handle_not_delivered:
self.handle_not_delivered(queued.context, queued.message)
continue # remove from buffer
deliver = False
if queued.state == QueuedOutboundMessage.STATE_PENDING:
deliver = True
elif queued.state == QueuedOutboundMessage.STATE_RETRY:
if queued.retry_at < loop_time:
queued.retry_at = None
deliver = True
if deliver:
queued.state = QueuedOutboundMessage.STATE_DELIVER
p_time = trace_event(
self.context.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.DELIVER.START."
+ queued.endpoint,
)
self.deliver_queued_message(queued)
trace_event(
self.context.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.DELIVER.END."
+ queued.endpoint,
perf_counter=p_time,
)
upd_buffer.append(queued)
new_pending = 0
new_messages = self.outbound_new
self.outbound_new = []
for queued in new_messages:
if queued.state == QueuedOutboundMessage.STATE_NEW:
if queued.message and queued.message.enc_payload:
queued.payload = queued.message.enc_payload
queued.state = QueuedOutboundMessage.STATE_PENDING
new_pending += 1
else:
queued.state = QueuedOutboundMessage.STATE_ENCODE
p_time = trace_event(
self.context.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.ENCODE.START",
)
self.encode_queued_message(queued)
trace_event(
self.context.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.ENCODE.END",
perf_counter=p_time,
)
else:
new_pending += 1
upd_buffer.append(queued)
self.outbound_buffer = upd_buffer
if self.outbound_buffer:
if not new_pending:
await self.outbound_event.wait()
else:
break
[docs] def encode_queued_message(self, queued: QueuedOutboundMessage) -> asyncio.Task:
"""Kick off encoding of a queued message."""
queued.task = self.task_queue.run(
self.perform_encode(queued),
lambda completed: self.finished_encode(queued, completed),
)
return queued.task
[docs] def finished_encode(self, queued: QueuedOutboundMessage, completed: CompletedTask):
"""Handle completion of queued message encoding."""
if completed.exc_info:
queued.error = completed.exc_info
queued.state = QueuedOutboundMessage.STATE_DONE
else:
queued.state = QueuedOutboundMessage.STATE_PENDING
queued.task = None
self.process_queued()
[docs] def deliver_queued_message(self, queued: QueuedOutboundMessage) -> asyncio.Task:
"""Kick off delivery of a queued message."""
transport = self.get_transport_instance(queued.transport_id)
queued.task = self.task_queue.run(
transport.handle_message(queued.context, queued.payload, queued.endpoint),
lambda completed: self.finished_deliver(queued, completed),
)
return queued.task
[docs] def finished_deliver(self, queued: QueuedOutboundMessage, completed: CompletedTask):
"""Handle completion of queued message delivery."""
if completed.exc_info:
queued.error = completed.exc_info
if queued.retries:
LOGGER.error(
">>> Posting error: %s; Re-queue failed message ...",
queued.endpoint,
)
queued.retries -= 1
queued.state = QueuedOutboundMessage.STATE_RETRY
queued.retry_at = time.perf_counter() + 10
else:
LOGGER.exception(
"Outbound message could not be delivered", exc_info=queued.error,
)
LOGGER.error(">>> NOT Re-queued, state is DONE, failed to deliver msg.")
queued.state = QueuedOutboundMessage.STATE_DONE
else:
queued.error = None
queued.state = QueuedOutboundMessage.STATE_DONE
queued.task = None
self.process_queued()
[docs] async def flush(self):
"""Wait for any queued messages to be delivered."""
proc_task = self.process_queued()
if proc_task:
await proc_task