"""Inbound connection handling classes."""
import asyncio
import logging
from typing import Callable, Sequence, Union
from ...config.injection_context import InjectionContext
from ..error import WireFormatError
from ..outbound.message import OutboundMessage
from ..wire_format import BaseWireFormat
from .message import InboundMessage
from .receipt import MessageReceipt
# Module-level logger; InboundSession.wait_response uses it to report
# buffered direct responses that fail wire-format encoding.
LOGGER = logging.getLogger(__name__)
class AcceptResult:
    """Represent the result of `InboundSession.accept_response`.

    The instance is truthy exactly when the message was accepted, so callers
    can write ``if session.accept_response(message): ...``.
    """

    def __init__(self, accepted: bool, retry: bool = False):
        """Initialize the `AcceptResult` instance.

        Args:
            accepted: True if the outbound message was buffered by the session
            retry: True if the message was not accepted now but matched the
                session, so it may be offered again later (for example when
                another response is already buffered)
        """
        self.accepted = accepted
        self.retry = retry

    def __bool__(self) -> bool:
        """Check if the result is true, i.e. the message was accepted."""
        # __bool__ must return an actual bool; coerce so that truthy/falsy
        # non-bool values (e.g. 1 or None) do not raise TypeError.
        return bool(self.accepted)

    def __repr__(self) -> str:
        """Return a debug representation of this result."""
        return "{}(accepted={!r}, retry={!r})".format(
            type(self).__name__, self.accepted, self.retry
        )
class InboundSession:
    """Track an open transport connection for direct routing of outbound messages."""

    def __init__(
        self,
        *,
        context: InjectionContext,
        inbound_handler: Callable,
        session_id: str,
        wire_format: BaseWireFormat,
        accept_undelivered: bool = False,
        can_respond: bool = False,
        client_info: dict = None,
        close_handler: Callable = None,
        reply_mode: str = None,
        reply_thread_ids: Sequence[str] = None,
        reply_verkeys: Sequence[str] = None,
        transport_type: str = None,
    ):
        """Initialize the inbound session.

        Args:
            context: Injection context passed to the wire format when parsing
                and encoding messages
            inbound_handler: Callable invoked with each received message and a
                ``can_respond`` keyword argument
            session_id: Identifier attached to every parsed inbound message
            wire_format: Wire format used to parse payloads and encode replies
            accept_undelivered: Stored flag; not read within this class
            can_respond: Whether this session may return direct responses
            client_info: Optional client metadata; stored, not read here
            close_handler: Optional callable invoked with this session on close
            reply_mode: Initial reply mode; unrecognized values become None
            reply_thread_ids: Initial thread IDs eligible for direct replies
            reply_verkeys: Initial verkeys eligible for direct replies
            transport_type: Transport label attached to parsed inbound messages
        """
        self.context = context
        self.inbound_handler = inbound_handler
        self.session_id = session_id
        self.wire_format = wire_format
        self.accept_undelivered = accept_undelivered
        self.client_info = client_info
        self.close_handler = close_handler
        self.response_buffer: OutboundMessage = None
        self.response_event = asyncio.Event()
        self.transport_type = transport_type
        self._can_respond = can_respond
        self._closed = False
        self._reply_mode = None
        self._reply_verkeys = None
        self._reply_thread_ids = None
        # call setters (they normalize the inputs into sets); reply_mode goes
        # last because a None/unknown mode clears the thread IDs set just before
        self.reply_thread_ids = reply_thread_ids
        self.reply_verkeys = reply_verkeys
        self.reply_mode = reply_mode

    @property
    def can_respond(self) -> bool:
        """Accessor for the session can-respond state (always False once closed)."""
        return self._can_respond and not self._closed

    @can_respond.setter
    def can_respond(self, can_respond: bool):
        """Setter for the session can-respond state."""
        self._can_respond = can_respond

    @property
    def closed(self) -> bool:
        """Accessor for the session closed state."""
        return self._closed

    def close(self):
        """Close the session.

        Marks the session closed, wakes any coroutine blocked in
        `wait_response`, and invokes the close handler if one was provided.
        NOTE: not guarded against repeat calls; the close handler runs each
        time this method is called.
        """
        self._closed = True
        self.response_event.set()  # end wait_response if blocked
        if self.close_handler:
            self.close_handler(self)

    @property
    def reply_mode(self) -> str:
        """Accessor for the session reply mode (None when replies are disabled)."""
        return self._reply_mode

    @reply_mode.setter
    def reply_mode(self, mode: str):
        """Setter for the session reply mode.

        Any value other than ``MessageReceipt.REPLY_MODE_ALL`` or
        ``MessageReceipt.REPLY_MODE_THREAD`` is normalized to None, which also
        clears the tracked reply thread IDs.
        """
        if mode not in (
            MessageReceipt.REPLY_MODE_ALL,
            MessageReceipt.REPLY_MODE_THREAD,
        ):
            mode = None
        self._reply_mode = mode
        if not mode:
            # reset the tracked thread IDs when the mode is changed to none
            self.reply_thread_ids = set()

    @property
    def reply_verkeys(self):
        """Accessor for the reply verkeys (a copy of the tracked set)."""
        return self._reply_verkeys.copy()

    @reply_verkeys.setter
    def reply_verkeys(self, verkeys: Sequence[str]):
        """Setter for the reply verkeys; a falsy value yields an empty set."""
        self._reply_verkeys = set(verkeys) if verkeys else set()

    @property
    def reply_thread_ids(self):
        """Accessor for the reply thread IDs (a copy of the tracked set)."""
        return self._reply_thread_ids.copy()

    @reply_thread_ids.setter
    def reply_thread_ids(self, thread_ids: Sequence[str]):
        """Setter for the reply thread IDs; a falsy value yields an empty set."""
        self._reply_thread_ids = set(thread_ids) if thread_ids else set()

    def add_reply_thread_ids(self, *thids):
        """Add thread IDs to the set of potential reply targets, skipping falsy ones."""
        for thid in filter(None, thids):
            self._reply_thread_ids.add(thid)

    def add_reply_verkeys(self, *verkeys):
        """Add verkeys to the set of potential reply targets, skipping falsy ones."""
        for verkey in filter(None, verkeys):
            self._reply_verkeys.add(verkey)

    @property
    def response_buffered(self) -> bool:
        """Check if a response is currently buffered."""
        return bool(self.response_buffer)

    def process_inbound(self, message: InboundMessage):
        """
        Process an incoming message and update the session metadata as necessary.

        The reply mode is taken from the receipt (None unless a direct
        response was requested), the sender verkey becomes a reply target,
        and in thread mode the message's thread ID is tracked as well.

        Args:
            message: The inbound message instance
        """
        receipt = message.receipt
        # The setter normalizes unknown/falsy modes to None, so compare the
        # normalized value read back from the property.
        self.reply_mode = (
            receipt.direct_response_requested and receipt.direct_response_mode
        )
        self.add_reply_verkeys(receipt.sender_verkey)
        if self.reply_mode == MessageReceipt.REPLY_MODE_THREAD:
            self.add_reply_thread_ids(receipt.thread_id)

    async def parse_inbound(self, payload_enc: Union[str, bytes]) -> InboundMessage:
        """Convert a message payload to an inbound message using the wire format."""
        payload, receipt = await self.wire_format.parse_message(
            self.context, payload_enc
        )
        return InboundMessage(
            payload,
            receipt,
            session_id=self.session_id,
            transport_type=self.transport_type,
        )

    async def receive(self, payload_enc: Union[str, bytes]) -> InboundMessage:
        """Receive a new message payload, dispatch it, and return the parsed message."""
        message = await self.parse_inbound(payload_enc)
        self.receive_inbound(message)
        return message

    def receive_inbound(self, message: InboundMessage):
        """Update session metadata from the message and deliver it to the handler."""
        self.process_inbound(message)
        self.inbound_handler(message, can_respond=self.can_respond)

    def select_outbound(self, message: OutboundMessage) -> bool:
        """Determine if an outbound message should be sent to this session.

        Args:
            message: The outbound message to be checked

        Returns:
            True if the session can respond, the message's reply verkey is a
            tracked reply target, and either the reply mode is "all" or the
            mode is "thread" and the message's reply thread ID is tracked.
        """
        if not self.can_respond:
            return False
        mode = self.reply_mode
        reply_verkey = message.reply_to_verkey
        reply_thread_id = message.reply_thread_id
        # Read the private sets directly; the public getters return copies.
        if reply_verkey and reply_verkey in self._reply_verkeys:
            if mode == MessageReceipt.REPLY_MODE_ALL:
                return True
            elif (
                mode == MessageReceipt.REPLY_MODE_THREAD
                and reply_thread_id
                and reply_thread_id in self._reply_thread_ids
            ):
                return True
        return False

    async def encode_outbound(self, outbound: OutboundMessage) -> Union[str, bytes]:
        """Apply wire formatting to an outbound message.

        Args:
            outbound: The message to encode for its reply verkey

        Returns:
            The encoded payload produced by the wire format

        Raises:
            WireFormatError: if the message has no payload or no reply verkey
        """
        if not outbound.payload:
            raise WireFormatError("Message has no payload to encode")
        if not outbound.reply_to_verkey:
            raise WireFormatError("No reply verkey available for encoding message")
        return await self.wire_format.encode_message(
            self.context,
            outbound.payload,
            [outbound.reply_to_verkey],
            None,
            outbound.reply_from_verkey,
        )

    def accept_response(self, message: OutboundMessage) -> AcceptResult:
        """
        Try to queue an outbound message if it applies to this session.

        Args:
            message: The outbound message to offer to this session

        Returns:
            An `AcceptResult` that is truthy if the message was buffered. When
            not accepted, ``retry`` is True if the message matched this session
            but another response is already buffered.
        """
        if not self.select_outbound(message):
            return AcceptResult(False, False)
        if self.response_buffer:
            return AcceptResult(False, True)
        self.set_response(message)
        return AcceptResult(True)

    def set_response(self, message: OutboundMessage):
        """Set the contents of the response message buffer and wake any waiter."""
        self.response_buffer = message
        self.response_event.set()

    def clear_response(self):
        """Handle when the buffered response message has been delivered."""
        self.response_buffer = None
        self.response_event.set()

    async def wait_response(self) -> Union[str, bytes, None]:
        """Wait for a response to be buffered and return it in encoded form.

        A buffered response that fails to encode is logged and discarded, and
        waiting continues.

        Returns:
            The encoded response, or None if the session is closed first
        """
        while True:
            # Clear the event *before* inspecting state. Any close() or
            # set_response() after this point, including during the encode
            # await below, sets it again, so the wakeup cannot be lost.
            self.response_event.clear()
            if self._closed:
                return None
            if self.response_buffer:
                response = self.response_buffer.enc_payload
                if not response:
                    try:
                        response = await self.encode_outbound(self.response_buffer)
                    except WireFormatError as e:
                        LOGGER.warning("Error encoding direct response: %s", str(e))
                        self.clear_response()
                if response:
                    return response
            await self.response_event.wait()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_value, exc_tb):
        """Async context manager exit: close the session."""
        self.close()