Source code for aries_cloudagent.transport.inbound.session

"""Inbound connection handling classes."""

import asyncio
import logging
from typing import Callable, Sequence, Union

from ...config.injection_context import InjectionContext

from ..error import WireFormatError
from ..outbound.message import OutboundMessage
from ..wire_format import BaseWireFormat

from .message import InboundMessage
from .receipt import MessageReceipt

LOGGER = logging.getLogger(__name__)


class AcceptResult:
    """Outcome of offering an outbound message to a session.

    Attributes:
        accepted: Whether the message was buffered by the session.
        retry: Whether the caller may offer the same message again later.

    The instance's truth value is its ``accepted`` flag.
    """

    def __init__(self, accepted: bool, retry: bool = False):
        """Store the accepted and retry flags on the new instance."""
        self.accepted, self.retry = accepted, retry

    def __bool__(self) -> bool:
        """Report the ``accepted`` flag as the truth value of the result."""
        return self.accepted
class InboundSession:
    """Track an open transport connection for direct routing of outbound messages.

    A session parses inbound payloads with its wire format, hands the
    resulting messages to the inbound handler, and keeps track of which
    sender verkeys and thread IDs may receive a reply over this same
    connection. At most one outbound message is buffered at a time; the
    transport retrieves it with `wait_response` and releases the slot with
    `clear_response`.

    The session can be used as an async context manager, in which case it
    is closed on exit.
    """

    def __init__(
        self,
        *,
        context: InjectionContext,
        inbound_handler: Callable,
        session_id: str,
        wire_format: BaseWireFormat,
        accept_undelivered: bool = False,
        can_respond: bool = False,
        client_info: dict = None,
        close_handler: Callable = None,
        reply_mode: str = None,
        reply_thread_ids: Sequence[str] = None,
        reply_verkeys: Sequence[str] = None,
        transport_type: str = None,
    ):
        """Initialize the inbound session.

        Args:
            context: Context passed to the wire format when parsing and
                encoding messages
            inbound_handler: Callable invoked as
                `inbound_handler(message, can_respond=...)` for each
                received message
            session_id: Identifier attached to every parsed inbound message
            wire_format: Wire format used to parse and encode payloads
            accept_undelivered: Stored on the session; not read by this class
            can_respond: Whether the connection can carry direct responses
            client_info: Stored on the session; not read by this class
            close_handler: Optional callable invoked with this session each
                time `close` is called
            reply_mode: Initial reply mode; any value other than the
                `MessageReceipt` reply mode constants is treated as no mode
            reply_thread_ids: Initial thread IDs eligible for replies
            reply_verkeys: Initial sender verkeys eligible for replies
            transport_type: Transport type attached to parsed inbound messages
        """
        self.context = context
        self.inbound_handler = inbound_handler
        self.session_id = session_id
        self.wire_format = wire_format
        self.accept_undelivered = accept_undelivered
        self.client_info = client_info
        self.close_handler = close_handler
        self.response_buffer: OutboundMessage = None
        self.response_event = asyncio.Event()
        self.transport_type = transport_type
        self._can_respond = can_respond
        self._closed = False
        self._reply_mode = None
        self._reply_verkeys = None
        self._reply_thread_ids = None
        # Go through the property setters so the inputs are normalized to
        # sets. The reply mode is applied last: an empty or unrecognized
        # mode clears the thread IDs that were just assigned.
        self.reply_thread_ids = reply_thread_ids
        self.reply_verkeys = reply_verkeys
        self.reply_mode = reply_mode

    @property
    def can_respond(self) -> bool:
        """Accessor for the session can-respond state.

        A closed session never reports that it can respond.
        """
        return self._can_respond and not self._closed

    @can_respond.setter
    def can_respond(self, can_respond: bool):
        """Setter for the session can-respond state."""
        self._can_respond = can_respond

    @property
    def closed(self) -> bool:
        """Accessor for the session closed state."""
        return self._closed

    def close(self):
        """Mark the session closed and notify interested parties.

        Sets the response event so that a blocked `wait_response` wakes up
        and returns, then invokes the close handler (if any) with this
        session.
        """
        self._closed = True
        self.response_event.set()  # end wait_response if blocked
        if self.close_handler:
            self.close_handler(self)

    @property
    def reply_mode(self) -> str:
        """Accessor for the session reply mode."""
        return self._reply_mode

    @reply_mode.setter
    def reply_mode(self, mode: str):
        """Setter for the session reply mode.

        Values other than `MessageReceipt.REPLY_MODE_ALL` and
        `MessageReceipt.REPLY_MODE_THREAD` are stored as `None`.
        """
        if mode not in (
            MessageReceipt.REPLY_MODE_ALL,
            MessageReceipt.REPLY_MODE_THREAD,
        ):
            mode = None
        self._reply_mode = mode
        if not mode:
            # reset the tracked thread IDs when the mode is changed to none
            self.reply_thread_ids = set()

    @property
    def reply_verkeys(self):
        """Accessor for the reply verkeys; returns a copy of the tracked set."""
        return self._reply_verkeys.copy()

    @reply_verkeys.setter
    def reply_verkeys(self, verkeys: Sequence[str]):
        """Setter for the reply verkeys; a falsy value resets to an empty set."""
        self._reply_verkeys = set(verkeys) if verkeys else set()

    @property
    def reply_thread_ids(self):
        """Accessor for the reply thread IDs; returns a copy of the tracked set."""
        return self._reply_thread_ids.copy()

    @reply_thread_ids.setter
    def reply_thread_ids(self, thread_ids: Sequence[str]):
        """Setter for the reply thread IDs; a falsy value resets to an empty set."""
        self._reply_thread_ids = set(thread_ids) if thread_ids else set()

    def add_reply_thread_ids(self, *thids):
        """Add thread IDs to the set of potential reply targets.

        Falsy entries (such as `None` or an empty string) are skipped.
        """
        self._reply_thread_ids.update(filter(None, thids))

    def add_reply_verkeys(self, *verkeys):
        """Add verkeys to the set of potential reply targets.

        Falsy entries (such as `None` or an empty string) are skipped.
        """
        self._reply_verkeys.update(filter(None, verkeys))

    @property
    def response_buffered(self) -> bool:
        """Check if a response is currently buffered."""
        return bool(self.response_buffer)

    def process_inbound(self, message: InboundMessage):
        """
        Process an incoming message and update the session metadata as necessary.

        The reply mode is taken from the message receipt when a direct
        response was requested, the sender verkey is recorded as a reply
        target, and in thread mode the receipt's thread ID is recorded too.

        Args:
            message: The inbound message instance
        """
        receipt = message.receipt
        mode = receipt.direct_response_requested and receipt.direct_response_mode
        # The setter validates the value; `mode` keeps the raw request so the
        # comparison below matches only the exact thread-mode constant.
        self.reply_mode = mode
        self.add_reply_verkeys(receipt.sender_verkey)
        if mode == MessageReceipt.REPLY_MODE_THREAD:
            self.add_reply_thread_ids(receipt.thread_id)

    async def parse_inbound(self, payload_enc: Union[str, bytes]) -> InboundMessage:
        """Convert an encoded message payload to an inbound message.

        Args:
            payload_enc: The encoded payload received from the transport

        Returns:
            An `InboundMessage` carrying the parsed payload and receipt along
            with this session's ID and transport type
        """
        payload, receipt = await self.wire_format.parse_message(
            self.context, payload_enc
        )
        return InboundMessage(
            payload,
            receipt,
            session_id=self.session_id,
            transport_type=self.transport_type,
        )

    async def receive(self, payload_enc: Union[str, bytes]) -> InboundMessage:
        """Receive a new message payload and dispatch the message.

        Args:
            payload_enc: The encoded payload received from the transport

        Returns:
            The parsed inbound message, after it was handed to the handler
        """
        message = await self.parse_inbound(payload_enc)
        self.receive_inbound(message)
        return message

    def receive_inbound(self, message: InboundMessage):
        """Update the session from the message and pass it to the inbound handler.

        Args:
            message: The inbound message instance
        """
        self.process_inbound(message)
        self.inbound_handler(message, can_respond=self.can_respond)

    def select_outbound(self, message: OutboundMessage) -> bool:
        """Determine if an outbound message should be sent to this session.

        The message is selected only when the session can respond, the
        message's reply-to verkey is a tracked reply target, and either the
        reply mode is "all" or the mode is "thread" and the message's reply
        thread ID is tracked.

        Args:
            message: The outbound message to be checked

        Returns:
            True if the message may be delivered over this session
        """
        if not self.can_respond:
            return False

        reply_verkey = message.reply_to_verkey
        reply_thread_id = message.reply_thread_id

        # Test against the private set directly: the public accessor copies
        # the whole set on every call, which is wasted work for a lookup.
        if not reply_verkey or reply_verkey not in self._reply_verkeys:
            return False

        mode = self.reply_mode
        if mode == MessageReceipt.REPLY_MODE_ALL:
            return True
        if mode == MessageReceipt.REPLY_MODE_THREAD:
            return bool(
                reply_thread_id and reply_thread_id in self._reply_thread_ids
            )
        return False

    async def encode_outbound(self, outbound: OutboundMessage) -> OutboundMessage:
        """Apply wire formatting to an outbound message.

        Args:
            outbound: The outbound message to encode

        Returns:
            The result of the wire format's `encode_message` call

        Raises:
            WireFormatError: If the message has no payload or no reply verkey

        """
        if not outbound.payload:
            raise WireFormatError("Message has no payload to encode")
        if not outbound.reply_to_verkey:
            raise WireFormatError("No reply verkey available for encoding message")
        return await self.wire_format.encode_message(
            self.context,
            outbound.payload,
            [outbound.reply_to_verkey],
            None,
            outbound.reply_from_verkey,
        )

    def accept_response(self, message: OutboundMessage) -> AcceptResult:
        """
        Try to queue an outbound message if it applies to this session.

        Args:
            message: The outbound message to be queued

        Returns:
            An `AcceptResult`: accepted when the message was buffered;
            not accepted with `retry` set when the message applies to this
            session but another response is still buffered; not accepted
            without `retry` when the message does not apply to this session

        """
        if not self.select_outbound(message):
            return AcceptResult(False, False)
        if self.response_buffer:
            return AcceptResult(False, True)
        self.set_response(message)
        return AcceptResult(True)

    def set_response(self, message: OutboundMessage):
        """Set the contents of the response message buffer and wake any waiter."""
        self.response_buffer = message
        self.response_event.set()

    def clear_response(self):
        """Handle when the buffered response message has been delivered.

        Empties the buffer and sets the response event so a waiter re-checks
        the session state.
        """
        self.response_buffer = None
        self.response_event.set()

    async def wait_response(self) -> Union[str, bytes, None]:
        """Wait for a response to be buffered and pack it.

        Returns:
            The encoded payload of the buffered response, or `None` once the
            session has been closed. The buffer is left in place; call
            `clear_response` after the payload has been delivered.
        """
        while True:
            if self._closed:
                return
            if self.response_buffer:
                # Prefer a payload that was already encoded upstream.
                response = self.response_buffer.enc_payload
                if not response:
                    try:
                        response = await self.encode_outbound(self.response_buffer)
                    except WireFormatError as e:
                        LOGGER.warning("Error encoding direct response: %s", str(e))
                        # Drop the unencodable message so the slot is freed.
                        self.clear_response()
                if response:
                    return response
            # Nothing deliverable yet: re-arm the event and sleep until
            # set_response, clear_response or close signals a change.
            self.response_event.clear()
            await self.response_event.wait()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_value, exc_tb):
        """Async context manager exit; closes the session."""
        self.close()