Source code for aries_cloudagent.core.dispatcher

"""
The Dispatcher.

The dispatcher is responsible for coordinating data flow between handlers, providing
lifecycle hook callbacks storing state for message threads, etc.
"""

import asyncio
import logging
import os
from typing import Callable, Coroutine, Union

from aiohttp.web import HTTPException

from ..config.injection_context import InjectionContext
from ..core.protocol_registry import ProtocolRegistry
from ..messaging.agent_message import AgentMessage
from ..messaging.error import MessageParseError
from ..messaging.models.base import BaseModelError
from ..messaging.request_context import RequestContext
from ..messaging.responder import BaseResponder
from ..messaging.util import datetime_now
from .error import ProtocolMinorVersionNotSupported

# FIXME: We shouldn't rely on a hardcoded message version here.
from ..protocols.connections.v1_0.manager import ConnectionManager
from ..protocols.problem_report.v1_0.message import ProblemReport

from ..transport.inbound.message import InboundMessage
from ..transport.outbound.message import OutboundMessage
from ..utils.stats import Collector
from ..utils.task_queue import CompletedTask, PendingTask, TaskQueue

from ..utils.tracing import trace_event, get_timer

LOGGER = logging.getLogger(__name__)


[docs]class Dispatcher: """ Dispatcher class. Class responsible for dispatching messages to message handlers and responding to other agents. """ def __init__(self, context: InjectionContext): """Initialize an instance of Dispatcher.""" self.context = context self.collector: Collector = None self.task_queue: TaskQueue = None
[docs] async def setup(self): """Perform async instance setup.""" self.collector = await self.context.inject(Collector, required=False) max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50)) self.task_queue = TaskQueue( max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task )
[docs] def put_task( self, coro: Coroutine, complete: Callable = None, ident: str = None ) -> PendingTask: """Run a task in the task queue, potentially blocking other handlers.""" return self.task_queue.put(coro, complete, ident)
[docs] def run_task( self, coro: Coroutine, complete: Callable = None, ident: str = None ) -> asyncio.Task: """Run a task in the task queue, potentially blocking other handlers.""" return self.task_queue.run(coro, complete, ident)
[docs] def log_task(self, task: CompletedTask): """Log a completed task using the stats collector.""" if task.exc_info and not issubclass(task.exc_info[0], HTTPException): # skip errors intentionally returned to HTTP clients LOGGER.exception( "Handler error: %s", task.ident or "", exc_info=task.exc_info ) if self.collector: timing = task.timing if "queued" in timing: self.collector.log( f"Dispatcher:queued", timing["unqueued"] - timing["queued"] ) if task.ident: self.collector.log(task.ident, timing["ended"] - timing["started"])
[docs] def queue_message( self, inbound_message: InboundMessage, send_outbound: Coroutine, send_webhook: Coroutine = None, complete: Callable = None, ) -> PendingTask: """ Add a message to the processing queue for handling. Args: inbound_message: The inbound message instance send_outbound: Async function to send outbound messages send_webhook: Async function to dispatch a webhook complete: Function to call when the handler has completed Returns: A pending task instance resolving to the handler task """ return self.put_task( self.handle_message(inbound_message, send_outbound, send_webhook), complete )
[docs] async def handle_message( self, inbound_message: InboundMessage, send_outbound: Coroutine, send_webhook: Coroutine = None, ): """ Configure responder and message context and invoke the message handler. Args: inbound_message: The inbound message instance send_outbound: Async function to send outbound messages send_webhook: Async function to dispatch a webhook Returns: The response from the handler """ r_time = get_timer() connection_mgr = ConnectionManager(self.context) connection = await connection_mgr.find_inbound_connection( inbound_message.receipt ) if connection: inbound_message.connection_id = connection.connection_id error_result = None try: message = await self.make_message(inbound_message.payload) except MessageParseError as e: LOGGER.error(f"Message parsing failed: {str(e)}, sending problem report") error_result = ProblemReport(explain_ltxt=str(e)) if inbound_message.receipt.thread_id: error_result.assign_thread_id(inbound_message.receipt.thread_id) message = None trace_event( self.context.settings, message, outcome="Dispatcher.handle_message.START", ) context = RequestContext(base_context=self.context) context.message = message context.message_receipt = inbound_message.receipt context.connection_ready = connection and connection.is_ready context.connection_record = connection responder = DispatcherResponder( context, inbound_message, send_outbound, send_webhook, connection_id=connection and connection.connection_id, reply_session_id=inbound_message.session_id, reply_to_verkey=inbound_message.receipt.sender_verkey, ) if error_result: await responder.send_reply(error_result) return context.injector.bind_instance(BaseResponder, responder) handler_cls = context.message.Handler handler = handler_cls().handle if self.collector: handler = self.collector.wrap_coro(handler, [handler.__qualname__]) await handler(context, responder) trace_event( self.context.settings, context.message, outcome="Dispatcher.handle_message.END", perf_counter=r_time, )
[docs] async def make_message(self, parsed_msg: dict) -> AgentMessage: """ Deserialize a message dict into the appropriate message instance. Given a dict describing a message, this method returns an instance of the related message class. Args: parsed_msg: The parsed message Returns: An instance of the corresponding message class for this message Raises: MessageParseError: If the message doesn't specify @type MessageParseError: If there is no message class registered to handle the given type """ registry: ProtocolRegistry = await self.context.inject(ProtocolRegistry) message_type = parsed_msg.get("@type") if not message_type: raise MessageParseError("Message does not contain '@type' parameter") try: message_cls = registry.resolve_message_class(message_type) except ProtocolMinorVersionNotSupported as e: raise MessageParseError(f"Problem parsing message type. {e}") if not message_cls: raise MessageParseError(f"Unrecognized message type {message_type}") try: instance = message_cls.deserialize(parsed_msg) except BaseModelError as e: raise MessageParseError(f"Error deserializing message: {e}") from e return instance
[docs] async def complete(self, timeout: float = 0.1): """Wait for pending tasks to complete.""" await self.task_queue.complete(timeout=timeout)
[docs]class DispatcherResponder(BaseResponder): """Handle outgoing messages from message handlers.""" def __init__( self, context: RequestContext, inbound_message: InboundMessage, send_outbound: Coroutine, send_webhook: Coroutine = None, **kwargs, ): """ Initialize an instance of `DispatcherResponder`. Args: context: The request context of the incoming message inbound_message: The inbound message triggering this handler send_outbound: Async function to send outbound message send_webhook: Async function to dispatch a webhook """ super().__init__(**kwargs) self._context = context self._inbound_message = inbound_message self._send = send_outbound self._webhook = send_webhook
[docs] async def create_outbound( self, message: Union[AgentMessage, str, bytes], **kwargs ) -> OutboundMessage: """ Create an OutboundMessage from a message body. Args: message: The message payload """ if isinstance(message, AgentMessage) and self._context.settings.get( "timing.enabled" ): # Inject the timing decorator in_time = ( self._context.message_receipt and self._context.message_receipt.in_time ) if not message._decorators.get("timing"): message._decorators["timing"] = { "in_time": in_time, "out_time": datetime_now(), } return await super().create_outbound(message, **kwargs)
[docs] async def send_outbound(self, message: OutboundMessage): """ Send outbound message. Args: message: The `OutboundMessage` to be sent """ await self._send(self._context, message, self._inbound_message)
[docs] async def send_webhook(self, topic: str, payload: dict): """ Dispatch a webhook. Args: topic: the webhook topic identifier payload: the webhook payload value """ if self._webhook: await self._webhook(topic, payload)