Source code for aries_cloudagent.messaging.agent_message

"""Agent message base class and schema."""

from collections import OrderedDict
from typing import Mapping, Optional, Text, Union
import uuid

from marshmallow import (
    EXCLUDE,
    ValidationError,
    fields,
    post_dump,
    post_load,
    pre_dump,
    pre_load,
)

from ..protocols.didcomm_prefix import DIDCommPrefix
from ..wallet.base import BaseWallet
from .base_message import BaseMessage, DIDCommVersion
from .decorators.base import BaseDecoratorSet
from .decorators.default import DecoratorSet
from .decorators.service_decorator import ServiceDecorator
from .decorators.signature_decorator import SignatureDecorator
from .decorators.thread_decorator import ThreadDecorator
from .decorators.trace_decorator import (
    TRACE_LOG_TARGET,
    TRACE_MESSAGE_TARGET,
    TraceDecorator,
    TraceReport,
)
from .message_type import MessageTypeStr
from .models.base import (
    BaseModel,
    BaseModelError,
    BaseModelSchema,
    resolve_class,
    resolve_meta_property,
)
from .valid import UUID4_EXAMPLE


class AgentMessageError(BaseModelError):
    """Root exception type for problems raised while handling agent messages."""
class AgentMessage(BaseModel, BaseMessage):
    """Agent message base class.

    Concrete message classes are expected to define `Meta.message_type`;
    instantiating a class without one raises `TypeError`.
    """

    class Meta:
        """AgentMessage metadata."""

        handler_class = None
        schema_class = None
        message_type = None

    def __init__(
        self,
        _id: Optional[str] = None,
        _type: Optional[Text] = None,
        _version: Optional[Text] = None,
        _decorators: Optional[BaseDecoratorSet] = None,
    ):
        """Initialize base agent message object.

        Args:
            _id: Agent message id; when falsy a new UUID4 string is generated
            _type: Message type to use in place of `Meta.message_type`
            _version: Version to apply to the message type, if provided
            _decorators: Message decorators; a new `DecoratorSet` is created
                when this is None

        Raises:
            TypeError: If message type is missing on subclass Meta class

        """
        super().__init__()
        if _id:
            self._message_id = _id
            self._message_new_id = False
        else:
            self._message_id = str(uuid.uuid4())
            # Records that the identifier was generated here, not supplied
            self._message_new_id = True
        self._message_decorators = (
            _decorators if _decorators is not None else DecoratorSet()
        )
        if not self.Meta.message_type:
            raise TypeError(
                "Can't instantiate abstract class {} with no message_type".format(
                    self.__class__.__name__
                )
            )
        self._message_type = MessageTypeStr(
            DIDCommPrefix.qualify_current(_type or self.Meta.message_type)
        )
        if _version:
            self.assign_version(_version)

    @classmethod
    def _get_handler_class(cls):
        """Get handler class.

        Returns:
            The resolved class defined on `Meta.handler_class`

        """
        return resolve_class(cls.Meta.handler_class, cls)

    @property
    def Handler(self) -> type:
        """Accessor for the agent message's handler class.

        Returns:
            Handler class

        """
        return self._get_handler_class()

    @property
    def _type(self) -> MessageTypeStr:
        """Accessor for the message type identifier.

        Returns:
            The qualified message type string built at construction time and
            updated by `assign_version`

        """
        return self._message_type

    @property
    def _id(self) -> str:
        """Accessor for the unique message identifier.

        Returns:
            The id of this message

        """
        return self._message_id

    @_id.setter
    def _id(self, val: str):
        """Set the unique message identifier."""
        self._message_id = val

    @property
    def _decorators(self) -> BaseDecoratorSet:
        """Fetch the message's decorator set."""
        return self._message_decorators

    @_decorators.setter
    def _decorators(self, value: BaseDecoratorSet):
        """Replace the message's decorator set."""
        self._message_decorators = value

    @property
    def _version(self) -> str:
        """Accessor for the message version, as a string."""
        return str(self._type.version)

    def assign_version_from(self, msg: "AgentMessage"):
        """Copy version information from a previous message.

        Args:
            msg: The received message containing version information to copy;
                nothing is done when this is falsy

        """
        if msg:
            self.assign_version(msg._version)

    def assign_version(self, version: str):
        """Assign a specific version.

        Args:
            version: The version to assign

        """
        self._message_type = self._message_type.with_version(version)

    def get_signature(self, field_name: str) -> SignatureDecorator:
        """Get the signature for a named field.

        Args:
            field_name: Field name to get the signature for

        Returns:
            The value stored under "sig" for the requested field name, or
            None when no signature is stored

        """
        return self._decorators.field(field_name).get("sig")

    def set_signature(self, field_name: str, signature: SignatureDecorator):
        """Add or replace the signature for a named field.

        Args:
            field_name: Field to set signature on
            signature: Signature for the field

        """
        self._decorators.field(field_name)["sig"] = signature

    async def sign_field(  # TODO migrate to signed-attachment per RFC 17
        self, field_name: str, signer_verkey: str, wallet: BaseWallet, timestamp=None
    ) -> SignatureDecorator:
        """Create and store a signature for a named field.

        Args:
            field_name: Field to sign
            signer_verkey: Verkey of signer
            wallet: Wallet to use for signature
            timestamp: Optional timestamp for signature

        Returns:
            A SignatureDecorator for newly created signature

        Raises:
            BaseModelError: If the named attribute is missing or None on this
                message

        """
        value = getattr(self, field_name, None)
        if value is None:
            raise BaseModelError(
                "{} field has no value for signature: {}".format(
                    self.__class__.__name__, field_name
                )
            )
        sig = await SignatureDecorator.create(value, signer_verkey, wallet, timestamp)
        self.set_signature(field_name, sig)
        return sig

    async def verify_signed_field(  # TODO migrate to signed-attachment per RFC 17
        self, field_name: str, wallet: BaseWallet, signer_verkey: Optional[str] = None
    ) -> str:
        """Verify a specific field signature.

        Args:
            field_name: The field name to verify
            wallet: Wallet to use for the verification
            signer_verkey: Verkey the signature is expected to carry; the
                comparison is skipped when this is None

        Returns:
            The verkey of the signer

        Raises:
            BaseModelError: If no signature is stored for the field
            BaseModelError: If the verification fails
            BaseModelError: If the verkey of the signature does not match the
                provided verkey

        """
        sig = self.get_signature(field_name)
        if not sig:
            raise BaseModelError("Missing field signature: {}".format(field_name))
        if not await sig.verify(wallet):
            raise BaseModelError(
                "Field signature verification failed: {}".format(field_name)
            )
        if signer_verkey is not None and sig.signer != signer_verkey:
            raise BaseModelError(
                "Signer verkey of signature does not match: {}".format(field_name)
            )
        return sig.signer

    async def verify_signatures(self, wallet: BaseWallet) -> bool:
        """Verify all associated field signatures.

        Args:
            wallet: Wallet to use in verification

        Returns:
            True if all signatures verify, else false

        """
        for field in self._decorators.fields.values():
            # Stop at the first signature that fails to verify
            if "sig" in field and not await field["sig"].verify(wallet):
                return False
        return True

    @property
    def _service(self) -> ServiceDecorator:
        """Accessor for the message's service decorator.

        Returns:
            The ServiceDecorator for this message

        """
        return self._decorators.get("service")

    @_service.setter
    def _service(self, val: Union[ServiceDecorator, dict, None]):
        """Setter for the message's service decorator.

        Args:
            val: ServiceDecorator or dict to set as the service; None removes
                any existing service decorator

        """
        if val is None:
            self._decorators.pop("service", None)
        else:
            self._decorators["service"] = val

    @property
    def _thread(self) -> ThreadDecorator:
        """Accessor for the message's thread decorator.

        Returns:
            The ThreadDecorator for this message

        """
        return self._decorators.get("thread")

    @_thread.setter
    def _thread(self, val: Union[ThreadDecorator, dict, None]):
        """Setter for the message's thread decorator.

        Args:
            val: ThreadDecorator or dict to set as the thread; None removes
                any existing thread decorator

        """
        if val is None:
            self._decorators.pop("thread", None)
        else:
            self._decorators["thread"] = val

    @property
    def _thread_id(self) -> str:
        """Accessor for the ID associated with this message.

        Returns:
            The thread decorator's `thid` when one is set, otherwise this
            message's own identifier

        """
        thread = self._thread
        if thread and thread.thid:
            return thread.thid
        return self._message_id

    def assign_thread_from(self, msg: "AgentMessage"):
        """Copy thread information from a previous message.

        Args:
            msg: The received message containing optional thread information;
                nothing is done when this is falsy

        """
        if msg:
            thread = msg._thread
            # Fall back to the message's own id when it carries no thread id
            thid = (thread.thid if thread else None) or msg._message_id
            pthid = thread.pthid if thread else None
            self.assign_thread_id(thid, pthid)

    def assign_thread_id(
        self, thid: Optional[str] = None, pthid: Optional[str] = None
    ):
        """Assign a specific thread ID.

        When both identifiers are falsy the thread decorator is removed.

        Args:
            thid: The thread identifier
            pthid: The parent thread identifier

        """
        if thid or pthid:
            self._thread = ThreadDecorator(thid=thid, pthid=pthid)
        else:
            self._thread = None

    @property
    def _trace(self) -> TraceDecorator:
        """Accessor for the message's trace decorator.

        Returns:
            The TraceDecorator for this message

        """
        return self._decorators.get("trace")

    @_trace.setter
    def _trace(self, val: Union[TraceDecorator, dict, None]):
        """Setter for the message's trace decorator.

        Args:
            val: TraceDecorator or dict to set as the trace; None removes any
                existing trace decorator

        """
        if val is None:
            self._decorators.pop("trace", None)
        else:
            self._decorators["trace"] = val

    def assign_trace_from(self, msg: "AgentMessage"):
        """Copy trace information from a previous message.

        Args:
            msg: The received message containing optional trace information

        """
        if not msg:
            return
        trace = msg._trace
        # ignore if not a valid type
        if trace and isinstance(trace, (TraceDecorator, dict)):
            self._trace = trace

    def assign_trace_decorator(self, context, trace):
        """Add a trace decorator when tracing is requested.

        Args:
            context: Object exposing `get`; when truthy its "trace.target"
                entry is used as the trace target, otherwise
                `TRACE_LOG_TARGET` is used
            trace: Only tested for truthiness; no decorator is added when it
                is falsy

        """
        if trace:
            # NOTE(review): a truthy context lacking "trace.target" passes
            # whatever `get` returns (presumably None) as target - confirm
            self.add_trace_decorator(
                target=context.get("trace.target") if context else TRACE_LOG_TARGET,
                full_thread=True,
            )

    def add_trace_decorator(
        self, target: str = TRACE_LOG_TARGET, full_thread: bool = True
    ):
        """Create a new trace decorator, or update the existing one.

        Args:
            target: The trace target
            full_thread: Full thread flag

        """
        if self._trace:
            # don't replace if there is already a trace decorator
            # (potentially holding trace reports already)
            self._trace._target = target
            self._trace._full_thread = full_thread
        else:
            self._trace = TraceDecorator(target=target, full_thread=full_thread)

    def add_trace_report(self, val: Union[TraceReport, dict]):
        """Append a new trace report.

        A trace decorator targeting `TRACE_MESSAGE_TARGET` is created first
        when the message does not have one yet.

        Args:
            val: The trace report to append

        """
        if not self._trace:
            self.add_trace_decorator(target=TRACE_MESSAGE_TARGET, full_thread=True)
        self._trace.append_trace_report(val)

    def serialize(self, msg_format: DIDCommVersion = DIDCommVersion.v1, **kwargs):
        """Return serialized message in format specified.

        Args:
            msg_format: DIDComm version to serialize for
            kwargs: Passed through to the parent class `serialize`

        Raises:
            NotImplementedError: If DIDComm v2 is requested

        """
        if msg_format is DIDCommVersion.v2:
            raise NotImplementedError("DIDComm v2 is not yet supported")
        return super().serialize(**kwargs)

    @classmethod
    def deserialize(
        cls, value: dict, msg_format: DIDCommVersion = DIDCommVersion.v1, **kwargs
    ):
        """Return message object deserialized from value in format specified.

        Args:
            value: The serialized message
            msg_format: DIDComm version the value is expressed in
            kwargs: Passed through to the parent class `deserialize`

        Raises:
            NotImplementedError: If DIDComm v2 is requested

        """
        if msg_format is DIDCommVersion.v2:
            raise NotImplementedError("DIDComm v2 is not yet supported")
        return super().deserialize(value, **kwargs)
class AgentMessageSchema(BaseModelSchema):
    """AgentMessage schema."""

    class Meta:
        """AgentMessageSchema metadata."""

        model_class = None
        signed_fields = None
        unknown = EXCLUDE

    # Avoid clobbering keywords
    _type = fields.Str(
        data_key="@type",
        required=False,
        metadata={
            "description": "Message type",
            "example": "https://didcomm.org/my-family/1.0/my-message-type",
        },
    )
    _id = fields.Str(
        data_key="@id",
        required=False,
        metadata={"description": "Message identifier", "example": UUID4_EXAMPLE},
    )

    def __init__(self, *args, **kwargs):
        """Initialize an instance of AgentMessageSchema.

        Delegates to the parent schema initializer, then sets up the
        per-instance decorator set and the dump-time decorator/signature state.
        """
        super().__init__(*args, **kwargs)
        self._decorators = DecoratorSet()
        self._decorators_dict = None
        self._signatures = {}

    @pre_load
    def extract_decorators(self, data: Mapping, **kwargs):
        """Pre-load hook to extract the decorators and check the signed fields.

        Args:
            data: Incoming data to parse

        Returns:
            Parsed and modified data

        Raises:
            ValidationError: If a field signature does not correlate
                to a field in the message
            ValidationError: If the message defines both a field signature
                and a value for the same field
            ValidationError: If there is a missing field signature

        """
        processed = self._decorators.extract_decorators(data, self.__class__)
        signed_names = resolve_meta_property(self, "signed_fields") or ()

        seen = set()
        for name, decorator_field in self._decorators.fields.items():
            if "sig" not in decorator_field:
                continue
            if name not in signed_names:
                raise ValidationError(
                    f"Encountered unexpected field signature: {name}"
                )
            if name in processed:
                raise ValidationError(
                    f"Message defines both field signature and value: {name}"
                )
            signature = decorator_field["sig"]
            seen.add(name)
            # The decoded pair is (value, timestamp); only the value is kept
            decoded_value, _timestamp = signature.decode()
            processed[name] = decoded_value

        for name in signed_names:
            if name not in seen:
                raise ValidationError(f"Expected field signature: {name}")

        return processed

    @post_load
    def populate_decorators(self, obj, **kwargs):
        """Post-load hook to populate decorators on the message.

        Args:
            obj: The AgentMessage object

        Returns:
            The AgentMessage object with populated decorators

        """
        obj._decorators = self._decorators
        return obj

    @pre_dump
    def check_dump_decorators(self, obj, **kwargs):
        """Pre-dump hook to validate and load the message decorators.

        Serialized field signatures are split out of a copy of the message's
        decorators and kept on the schema for the post-dump hooks.

        Args:
            obj: The AgentMessage object

        Returns:
            The unmodified AgentMessage object

        Raises:
            BaseModelError: If a field listed in `signed_fields` has no
                signature

        """
        working_copy = obj._decorators.copy()
        collected = OrderedDict()
        for name, decorator_field in working_copy.fields.items():
            if "sig" in decorator_field:
                collected[name] = decorator_field["sig"].serialize()
                del decorator_field["sig"]
        self._decorators_dict = working_copy.to_dict()
        self._signatures = collected

        # check existence of signatures
        required = resolve_meta_property(self, "signed_fields") or ()
        for name in required:
            if name not in self._signatures:
                raise BaseModelError(f"Missing signature for field: {name}")

        return obj

    @post_dump
    def dump_decorators(self, data, **kwargs):
        """Post-dump hook to write the decorators to the serialized output.

        Args:
            data: The serialized data

        Returns:
            A new ordered mapping holding "@type" and "@id" first (when
            present), then the decorators, then the remaining data

        """
        ordered = OrderedDict()
        for header in ("@type", "@id"):
            if header in data:
                ordered[header] = data.pop(header)
        ordered.update(self._decorators_dict)
        ordered.update(data)
        return ordered

    @post_dump
    def replace_signatures(self, data, **kwargs):
        """Post-dump hook to write the signatures to the serialized output.

        Args:
            data: The serialized data

        Returns:
            The same data, with each signed field replaced by a
            "<field>~sig" entry

        """
        for name, serialized_sig in self._signatures.items():
            del data[name]
            data[f"{name}~sig"] = serialized_sig
        return data