Source code for acapy_agent.messaging.models.base_record

"""Classes for BaseStorage-based record management."""

import json
import logging
import sys
from datetime import datetime
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union

from marshmallow import fields
from uuid_utils import uuid4

from ...cache.base import BaseCache
from ...config.settings import BaseSettings
from ...core.profile import ProfileSession
from ...storage.base import (
    DEFAULT_PAGE_SIZE,
    BaseStorage,
    StorageDuplicateError,
    StorageNotFoundError,
)
from ...storage.record import StorageRecord
from ..util import datetime_to_str, time_now
from ..valid import ISO8601_DATETIME_EXAMPLE, ISO8601_DATETIME_VALIDATE
from .base import BaseModel, BaseModelError, BaseModelSchema

LOGGER = logging.getLogger(__name__)


RecordType = TypeVar("RecordType", bound="BaseRecord")


[docs] def match_post_filter( record: dict, post_filter: dict, positive: bool = True, alt: bool = False, ) -> bool: """Determine if a record value matches the post-filter. Args: record: record to check post_filter: filter to apply (empty or None filter matches everything) positive: whether matching all filter criteria positively or negatively alt: set to match any (positive=True) value or miss all (positive=False) values in post_filter """ if not post_filter: return True if alt: return ( positive and all( record.get(k) and record.get(k) in alts for k, alts in post_filter.items() ) ) or ( (not positive) and all( record.get(k) and record.get(k) not in alts for k, alts in post_filter.items() ) ) for k, v in post_filter.items(): if record.get(k) != v: return not positive return positive
[docs] class BaseRecord(BaseModel): """Represents a single storage record."""
[docs] class Meta: """BaseRecord metadata."""
DEFAULT_CACHE_TTL = 60 RECORD_ID_NAME = "id" RECORD_TYPE = None RECORD_TOPIC: Optional[str] = None EVENT_NAMESPACE: str = "acapy::record" LOG_STATE_FLAG = None TAG_NAMES = {"state"} STATE_DELETED = "deleted" def __init__( self, id: Optional[str] = None, state: Optional[str] = None, *, created_at: Union[str, datetime, None] = None, updated_at: Union[str, datetime, None] = None, new_with_id: bool = False, ): """Initialize a new BaseRecord.""" if not self.RECORD_TYPE: raise TypeError( "Cannot instantiate abstract class {} with no RECORD_TYPE".format( self.__class__.__name__ ) ) self._id = id self._last_state = state self._new_with_id = new_with_id self.state = state self.created_at = datetime_to_str(created_at) self.updated_at = datetime_to_str(updated_at)
[docs] @classmethod def from_storage(cls, record_id: str, record: Mapping[str, Any]): """Initialize a record from its stored representation. Args: record_id: The unique record identifier record: The stored representation """ record_id_name = cls.RECORD_ID_NAME if record_id_name in record: raise ValueError(f"Duplicate {record_id_name} inputs; {record}") params = dict(**record) params[record_id_name] = record_id return cls(**params)
[docs] @classmethod def get_tag_map(cls) -> Mapping[str, str]: """Accessor for the set of defined tags.""" return {tag.lstrip("~"): tag for tag in cls.TAG_NAMES or ()}
@property def storage_record(self) -> StorageRecord: """Accessor for a `StorageRecord` representing this record.""" return StorageRecord( self.RECORD_TYPE, json.dumps(self.value), self.tags, self._id ) @property def record_value(self) -> dict: """Accessor to define custom properties for the JSON record value.""" return {} @property def value(self) -> dict: """Accessor for the JSON record value generated for this record.""" ret = self.strip_tag_prefix(self.tags) ret.update({"created_at": self.created_at, "updated_at": self.updated_at}) ret.update(self.record_value) return ret @property def record_tags(self) -> dict: """Accessor to define implementation-specific tags.""" return { tag: getattr(self, prop) for (prop, tag) in self.get_tag_map().items() if getattr(self, prop) is not None } @property def tags(self) -> dict: """Accessor for the record tags generated for this record.""" tags = self.record_tags return tags
[docs] @classmethod async def get_cached_key(cls, session: ProfileSession, cache_key: str): """Shortcut method to fetch a cached key value. Args: session: The profile session to use cache_key: The unique cache identifier """ if not cache_key: return cache = session.inject_or(BaseCache) if cache: return await cache.get(cache_key)
[docs] @classmethod async def set_cached_key( cls, session: ProfileSession, cache_key: str, value: Any, ttl=None ): """Shortcut method to set a cached key value. Args: session: The profile session to use cache_key: The unique cache identifier value: The value to cache ttl: The cache ttl """ if not cache_key: return cache = session.inject_or(BaseCache) if cache: await cache.set(cache_key, value, ttl or cls.DEFAULT_CACHE_TTL)
[docs] @classmethod async def clear_cached_key(cls, session: ProfileSession, cache_key: str): """Shortcut method to clear a cached key value, if any. Args: session: The profile session to use cache_key: The unique cache identifier """ if not cache_key: return cache = session.inject_or(BaseCache) if cache: await cache.clear(cache_key)
[docs] @classmethod async def retrieve_by_id( cls: Type[RecordType], session: ProfileSession, record_id: str, *, for_update=False, ) -> RecordType: """Retrieve a stored record by ID. Args: session: The profile session to use record_id: The ID of the record to find for_update: Whether to lock the record for update """ storage = session.inject(BaseStorage) result = await storage.get_record( cls.RECORD_TYPE, record_id, options={"forUpdate": for_update} ) vals = json.loads(result.value) return cls.from_storage(record_id, vals)
[docs] @classmethod async def retrieve_by_tag_filter( cls: Type[RecordType], session: ProfileSession, tag_filter: dict, post_filter: Optional[dict] = None, *, for_update=False, ) -> RecordType: """Retrieve a record by tag filter. Args: cls: The record class session: The profile session to use tag_filter: The filter dictionary to apply post_filter: Additional value filters to apply matching positively, with sequence values specifying alternatives to match (hit any) for_update: Whether to lock the record for update """ storage = session.inject(BaseStorage) rows = await storage.find_all_records( cls.RECORD_TYPE, cls.prefix_tag_filter(tag_filter), options={"forUpdate": for_update}, ) found = None for record in rows: vals = json.loads(record.value) if match_post_filter(vals, post_filter, alt=False): if found: raise StorageDuplicateError( "Multiple {} records located for {}{}".format( cls.__name__, tag_filter, f", {post_filter}" if post_filter else "", ) ) found = cls.from_storage(record.id, vals) if not found: raise StorageNotFoundError( "{} record not found for {}{}".format( cls.__name__, tag_filter, f", {post_filter}" if post_filter else "" ) ) return found
[docs] @classmethod async def query( cls: Type[RecordType], session: ProfileSession, tag_filter: Optional[dict] = None, *, limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[str] = None, descending: bool = False, post_filter_positive: Optional[dict] = None, post_filter_negative: Optional[dict] = None, alt: bool = False, ) -> Sequence[RecordType]: """Query stored records. Args: session: The profile session to use tag_filter: An optional dictionary of tag filter clauses limit: The maximum number of records to retrieve offset: The offset to start retrieving records from order_by: An optional field by which to order the records. descending: Whether to order the records in descending order. post_filter_positive: Additional value filters to apply matching positively post_filter_negative: Additional value filters to apply matching negatively alt: set to match any (positive=True) value or miss all (positive=False) values in post_filter """ storage = session.inject(BaseStorage) tag_query = cls.prefix_tag_filter(tag_filter) post_filter = post_filter_positive or post_filter_negative # set flag to indicate if pagination is requested or not, then set defaults paginated = limit is not None or offset is not None limit = limit or DEFAULT_PAGE_SIZE offset = offset or 0 if not post_filter and paginated: # Only fetch paginated records if post-filter is not being applied rows = await storage.find_paginated_records( type_filter=cls.RECORD_TYPE, tag_query=tag_query, limit=limit, offset=offset, order_by=order_by, descending=descending, ) else: rows = await storage.find_all_records( type_filter=cls.RECORD_TYPE, tag_query=tag_query, order_by=order_by, descending=descending, ) num_results_post_filter = 0 # used if applying pagination post-filter num_records_to_match = limit + offset # ignored if not paginated result = [] for record in rows: try: vals = json.loads(record.value) if not post_filter: # pagination would already be applied if requested result.append(cls.from_storage(record.id, vals)) else: continue_processing = ( not paginated or num_results_post_filter < num_records_to_match ) if not continue_processing: break post_filter_match = match_post_filter( vals, post_filter_positive, positive=True, alt=alt ) and match_post_filter( vals, post_filter_negative, positive=False, alt=alt ) if not post_filter_match: continue if num_results_post_filter >= offset: # append only after offset result.append(cls.from_storage(record.id, vals)) num_results_post_filter += 1 except (BaseModelError, json.JSONDecodeError, TypeError) as err: raise BaseModelError(f"{err}, for record id {record.id}") return result
[docs] async def save( self, session: ProfileSession, *, reason: Optional[str] = None, log_params: Mapping[str, Any] = None, log_override: bool = False, event: Optional[bool] = None, ) -> str: """Persist the record to storage. Args: session: The profile session to use reason: A reason to add to the log log_params: Additional parameters to log log_override: Override configured logging regimen, print to stderr instead event: Flag to override whether the event is sent """ new_record = None log_reason = reason or ("Updated record" if self._id else "Created record") try: self.updated_at = time_now() storage = session.inject(BaseStorage) if self._id and not self._new_with_id: record = self.storage_record await storage.update_record(record, record.value, record.tags) new_record = False else: if not self._id: self._id = str(uuid4()) self.created_at = self.updated_at await storage.add_record(self.storage_record) new_record = True self._new_with_id = False finally: params = {self.RECORD_TYPE: self.serialize()} if log_params: params.update(log_params) if new_record is None: log_reason = f"FAILED: {log_reason}" self.log_state( log_reason, params, override=log_override, settings=session.settings ) await self.post_save(session, new_record, self._last_state, event) self._last_state = self.state return self._id
[docs] async def post_save( self, session: ProfileSession, new_record: bool, last_state: Optional[str], event: Optional[bool] = None, ): """Perform post-save actions. Args: session: The profile session to use new_record: Flag indicating if the record was just created last_state: The previous state value event: Flag to override whether the event is sent """ if event is None: event = new_record or (last_state != self.state) if event: await self.emit_event(session, self.serialize())
[docs] async def delete_record(self, session: ProfileSession): """Remove the stored record. Args: session: The profile session to use """ if self._id: storage = session.inject(BaseStorage) if self.state: self._previous_state = self.state self.state = BaseRecord.STATE_DELETED await self.emit_event(session, self.serialize()) await storage.delete_record(self.storage_record)
[docs] async def emit_event(self, session: ProfileSession, payload: Optional[Any] = None): """Emit an event. Args: session: The profile session to use payload: The event payload """ if not self.RECORD_TOPIC: return if self.state: topic = f"{self.EVENT_NAMESPACE}::{self.RECORD_TOPIC}::{self.state}" else: topic = f"{self.EVENT_NAMESPACE}::{self.RECORD_TOPIC}" if not payload: payload = self.serialize() await session.emit_event(topic, payload)
[docs] @classmethod def log_state( cls, msg: str, params: Optional[dict] = None, settings: Optional[BaseSettings] = None, override: bool = False, ): """Print a message with increased visibility (for testing).""" if override or ( cls.LOG_STATE_FLAG and settings and settings.get(cls.LOG_STATE_FLAG) ): out = msg + "\n" if params: for k, v in params.items(): out += f" {k}: {v}\n" print(out, file=sys.stderr)
[docs] @classmethod def strip_tag_prefix(cls, tags: dict): """Strip tilde from unencrypted tag names.""" return {(k[1:] if "~" in k else k): v for (k, v) in tags.items()} if tags else {}
[docs] @classmethod def prefix_tag_filter(cls, tag_filter: dict): """Prefix unencrypted tags used in the tag filter.""" ret = None if tag_filter: tag_map = cls.get_tag_map() ret = {} for k, v in tag_filter.items(): if k in ("$or", "$and") and isinstance(v, list): ret[k] = [cls.prefix_tag_filter(clause) for clause in v] elif k == "$not" and isinstance(v, dict): ret[k] = cls.prefix_tag_filter(v) else: ret[tag_map.get(k, k)] = v return ret
def __eq__(self, other: Any) -> bool: """Comparison between records.""" if type(other) is type(self): return self.value == other.value and self.tags == other.tags return False
[docs] @classmethod def get_attributes_by_prefix(cls, prefix: str, walk_mro: bool = True): """List all values for attributes with common prefix. Args: prefix: Common prefix to look for walk_mro: Walk MRO to find attributes inherited from superclasses """ bases = cls.__mro__ if walk_mro else [cls] return [ vars(base)[name] for base in bases for name in vars(base) if name.startswith(prefix) ]
[docs] class BaseExchangeRecord(BaseRecord): """Represents a base record with event tracing capability.""" def __init__( self, id: Optional[str] = None, state: Optional[str] = None, *, trace: bool = False, **kwargs, ): """Initialize a new BaseExchangeRecord.""" super().__init__(id, state, **kwargs) self.trace = trace def __eq__(self, other: Any) -> bool: """Comparison between records.""" if type(other) is type(self): return ( self.value == other.value and self.tags == other.tags and self.trace == other.trace ) return False
[docs] class BaseRecordSchema(BaseModelSchema): """Schema to allow serialization/deserialization of base records."""
[docs] class Meta: """BaseRecordSchema metadata.""" model_class = None
state = fields.Str( required=False, metadata={"description": "Current record state", "example": "active"}, ) created_at = fields.Str( required=False, validate=ISO8601_DATETIME_VALIDATE, metadata={ "description": "Time of record creation", "example": ISO8601_DATETIME_EXAMPLE, }, ) updated_at = fields.Str( required=False, validate=ISO8601_DATETIME_VALIDATE, metadata={ "description": "Time of last record update", "example": ISO8601_DATETIME_EXAMPLE, }, )
[docs] class BaseExchangeSchema(BaseRecordSchema): """Base schema for exchange records."""
[docs] class Meta: """BaseExchangeSchema metadata.""" model_class = BaseExchangeRecord
trace = fields.Boolean( required=False, dump_default=False, metadata={ "description": "Record trace information, based on agent configuration" }, )