# Source code for acapy_agent.storage.kanon_storage

"""Kanon storage implementation for non-secrets storage."""

import asyncio
import inspect
import logging
from typing import Mapping, Optional, Sequence

from ..core.profile import Profile
from ..database_manager.dbstore import DBStoreError, DBStoreErrorCode, DBStoreSession
from .base import (
    DEFAULT_PAGE_SIZE,
    BaseStorage,
    BaseStorageSearch,
    BaseStorageSearchSession,
    validate_record,
)
from .error import (
    StorageDuplicateError,
    StorageError,
    StorageNotFoundError,
    StorageSearchError,
)
from .record import StorageRecord

# Logger shared by KanonStorage and KanonStorageSearchSession.
LOGGER = logging.getLogger(__name__)

# Message for the StorageSearchError raised when reading from a search scan fails.
ERR_FETCH_SEARCH_RESULTS: str = "Error when fetching search results"


class KanonStorage(BaseStorage):
    """Kanon Non-Secrets interface.

    Every public method that accepts an optional ``session`` opens a temporary
    session from ``self._session.store.session()`` when none is supplied, and
    runs the operation inside it.
    """

    def __init__(self, session: Profile):
        """Initialize KanonStorage with a profile session.

        Args:
            session: Despite the ``Profile`` annotation, this object is used as a
                profile session. It must expose ``store`` (providing
                ``session()``, ``scan()`` and ``scan_keyset()``), ``profile``
                (providing ``name``) and ``dbstore_handle``.
        """
        self._session = session

    @property
    def session(self) -> DBStoreSession:
        """Get the database session (the wrapped session's ``dbstore_handle``)."""
        return self._session.dbstore_handle

    @staticmethod
    def _row_to_record(row) -> StorageRecord:
        """Convert a DBStore row (category/name/value/tags) to a StorageRecord."""
        return StorageRecord(
            type=row.category,
            id=row.name,
            value=row.value,
            tags=row.tags or {},
        )

    async def add_record(
        self, record: StorageRecord, session: Optional[DBStoreSession] = None
    ):
        """Add a new record to storage.

        Args:
            record: The record to insert; validated with ``validate_record``.
            session: Optional open DBStore session to use.

        Raises:
            StorageDuplicateError: If the DBStore reports a duplicate record.
            StorageError: On any other DBStore failure.
        """
        validate_record(record)
        if session is None:
            async with self._session.store.session() as temp_session:
                await self._add_record(record, temp_session)
        else:
            await self._add_record(record, session)

    async def _add_record(self, record: StorageRecord, session: DBStoreSession):
        """Insert ``record`` using ``session``, mapping DBStore errors."""
        try:
            await self._call_handle_or_session(
                session, "insert", record.type, record.id, record.value, record.tags
            )
        except DBStoreError as err:
            if err.code == DBStoreErrorCode.DUPLICATE:
                raise StorageDuplicateError(
                    f"Duplicate record: {record.type}/{record.id}"
                ) from None
            raise StorageError("Error when adding storage record") from err

    async def get_record(
        self,
        record_type: str,
        record_id: str,
        options: Optional[Mapping] = None,
        session: Optional[DBStoreSession] = None,
    ) -> StorageRecord:
        """Retrieve a single record by type and ID.

        Args:
            record_type: Record category; must be non-empty.
            record_id: Record name; must be non-empty.
            options: Optional mapping; a truthy ``"forUpdate"`` requests a
                fetch for update.
            session: Optional open DBStore session to use.

        Raises:
            StorageError: If type/ID are missing or the fetch fails.
            StorageNotFoundError: If no such record exists.
        """
        if not record_type:
            raise StorageError("Record type not provided")
        if not record_id:
            raise StorageError("Record ID not provided")
        for_update = bool(options and options.get("forUpdate"))
        if session is None:
            async with self._session.store.session() as temp_session:
                return await self._get_record(
                    record_type, record_id, for_update, temp_session
                )
        return await self._get_record(record_type, record_id, for_update, session)

    async def _get_record(
        self, record_type: str, record_id: str, for_update: bool, session: DBStoreSession
    ) -> StorageRecord:
        """Fetch one record with ``session`` and convert it to a StorageRecord."""
        try:
            item = await self._call_handle_or_session(
                session, "fetch", record_type, record_id, for_update=for_update
            )
        except DBStoreError as err:
            raise StorageError("Error when fetching storage record") from err
        if not item:
            raise StorageNotFoundError(f"Record not found: {record_type}/{record_id}")
        return self._row_to_record(item)

    async def update_record(
        self,
        record: StorageRecord,
        value: str,
        tags: Mapping,
        session: Optional[DBStoreSession] = None,
    ):
        """Update an existing record's value and tags.

        Raises:
            StorageNotFoundError: If the record does not exist.
            StorageError: On any other DBStore failure.
        """
        validate_record(record)
        if session is None:
            async with self._session.store.session() as temp_session:
                await self._update_record(record, value, tags, temp_session)
        else:
            await self._update_record(record, value, tags, session)

    async def _update_record(
        self, record: StorageRecord, value: str, tags: Mapping, session: DBStoreSession
    ):
        """Check the record exists (fetch for update), then replace it."""
        try:
            item = await self._call_handle_or_session(
                session, "fetch", record.type, record.id, for_update=True
            )
            if not item:
                raise StorageNotFoundError(f"Record not found: {record.type}/{record.id}")
            await self._call_handle_or_session(
                session, "replace", record.type, record.id, value, tags
            )
        except DBStoreError as err:
            if err.code == DBStoreErrorCode.NOT_FOUND:
                raise StorageNotFoundError(
                    f"Record not found: {record.type}/{record.id}"
                ) from None
            raise StorageError("Error when updating storage record value") from err

    async def delete_record(
        self, record: StorageRecord, session: Optional[DBStoreSession] = None
    ):
        """Delete a record from storage.

        Raises:
            StorageNotFoundError: If the DBStore reports the record missing.
            StorageError: On any other DBStore failure.
        """
        validate_record(record, delete=True)
        if session is None:
            async with self._session.store.session() as temp_session:
                await self._delete_record(record, temp_session)
        else:
            await self._delete_record(record, session)

    async def _delete_record(self, record: StorageRecord, session: DBStoreSession):
        """Remove ``record`` using ``session``, mapping DBStore errors."""
        try:
            await self._call_handle_or_session(session, "remove", record.type, record.id)
        except DBStoreError as err:
            if err.code == DBStoreErrorCode.NOT_FOUND:
                raise StorageNotFoundError(
                    f"Record not found: {record.type}/{record.id}"
                ) from None
            raise StorageError("Error when removing storage record") from err

    async def find_record(
        self,
        type_filter: str,
        tag_query: Mapping,
        options: Optional[Mapping] = None,
        session: Optional[DBStoreSession] = None,
    ) -> StorageRecord:
        """Find a single record matching the type and tag query.

        Raises:
            StorageDuplicateError: If more than one record matches.
            StorageNotFoundError: If no record matches.
            StorageError: If the query itself fails.
        """
        for_update = bool(options and options.get("forUpdate"))
        if session is None:
            async with self._session.store.session() as temp_session:
                return await self._find_record(
                    type_filter, tag_query, for_update, temp_session
                )
        return await self._find_record(type_filter, tag_query, for_update, session)

    async def _find_record(
        self,
        type_filter: str,
        tag_query: Mapping,
        for_update: bool,
        session: DBStoreSession,
    ) -> StorageRecord:
        """Fetch at most two matches so that duplicates can be detected."""
        try:
            results = await self._call_handle_or_session(
                session,
                "fetch_all",
                type_filter,
                tag_query,
                limit=2,
                for_update=for_update,
            )
        except DBStoreError as err:
            raise StorageError("Error when finding storage record") from err
        if len(results) > 1:
            raise StorageDuplicateError("Duplicate records found")
        if not results:
            raise StorageNotFoundError("Record not found")
        return self._row_to_record(results[0])

    async def find_paginated_records(
        self,
        type_filter: str,
        tag_query: Optional[Mapping] = None,
        limit: int = DEFAULT_PAGE_SIZE,
        offset: int = 0,
        order_by: Optional[str] = None,
        descending: bool = False,
    ) -> Sequence[StorageRecord]:
        """Retrieve paginated records using DBStore.scan.

        Raises:
            StorageError: If opening or iterating the scan fails with a
                DBStoreError (consistent with the other query methods).
        """
        LOGGER.debug(
            "find_paginated_records: type=%s, tags=%s, limit=%s, "
            "offset=%s, order=%s, desc=%s",
            type_filter,
            tag_query,
            limit,
            offset,
            order_by,
            descending,
        )
        try:
            scan = self._session.store.scan(
                category=type_filter,
                tag_filter=tag_query,
                limit=limit,
                offset=offset,
                profile=self._session.profile.name,
                order_by=order_by,
                descending=descending,
            )
            return [self._row_to_record(row) async for row in scan]
        except DBStoreError as err:
            raise StorageError("Error when fetching paginated records") from err

    async def find_paginated_records_keyset(
        self,
        type_filter: str,
        tag_query: Optional[Mapping] = None,
        last_id: Optional[int] = None,
        limit: int = DEFAULT_PAGE_SIZE,
        order_by: Optional[str] = None,
        descending: bool = False,
    ) -> Sequence[StorageRecord]:
        """Retrieve paginated records using DBStore.scan_keyset.

        Args:
            last_id: Keyset cursor passed through to ``scan_keyset``; ``None``
                is forwarded as-is.

        Raises:
            StorageError: If opening or iterating the scan fails with a
                DBStoreError (consistent with the other query methods).
        """
        LOGGER.debug(
            "find_paginated_records_keyset: type=%s, tags=%s, last_id=%s, "
            "limit=%s, order=%s, desc=%s",
            type_filter,
            tag_query,
            last_id,
            limit,
            order_by,
            descending,
        )
        try:
            scan = self._session.store.scan_keyset(
                category=type_filter,
                tag_filter=tag_query,
                last_id=last_id,
                limit=limit,
                profile=self._session.profile.name,
                order_by=order_by,
                descending=descending,
            )
            return [self._row_to_record(row) async for row in scan]
        except DBStoreError as err:
            raise StorageError("Error when fetching paginated records") from err

    async def find_all_records(
        self,
        type_filter: str,
        tag_query: Optional[Mapping] = None,
        order_by: Optional[str] = None,
        descending: bool = False,
        options: Optional[Mapping] = None,
        session: Optional[DBStoreSession] = None,
    ) -> Sequence[StorageRecord]:
        """Retrieve all records matching the type and tag query.

        Raises:
            StorageError: If the DBStore query fails.
        """
        for_update = bool(options and options.get("forUpdate"))
        if session is None:
            async with self._session.store.session() as temp_session:
                return await self._find_all_records(
                    type_filter, tag_query, order_by, descending, for_update, temp_session
                )
        return await self._find_all_records(
            type_filter, tag_query, order_by, descending, for_update, session
        )

    async def _find_all_records(
        self,
        type_filter: str,
        tag_query: Optional[Mapping],
        order_by: Optional[str],
        descending: bool,
        for_update: bool,
        session: DBStoreSession,
    ) -> Sequence[StorageRecord]:
        """Run ``fetch_all`` with ``session`` and convert every row."""
        try:
            rows = await self._call_handle_or_session(
                session,
                "fetch_all",
                type_filter,
                tag_query,
                order_by=order_by,
                descending=descending,
                for_update=for_update,
            )
            # Iterate inside the try: a sync handle may return a lazy iterable
            # that raises DBStoreError while being consumed.
            return [self._row_to_record(row) for row in rows]
        except DBStoreError as err:
            raise StorageError("Failed to fetch records") from err

    async def delete_all_records(
        self,
        type_filter: str,
        tag_query: Optional[Mapping] = None,
        session: Optional[DBStoreSession] = None,
    ):
        """Delete all records matching the type and tag query.

        Raises:
            StorageError: If the DBStore removal fails.
        """
        if session is None:
            async with self._session.store.session() as temp_session:
                await self._delete_all_records(type_filter, tag_query, temp_session)
        else:
            await self._delete_all_records(type_filter, tag_query, session)

    async def _delete_all_records(
        self, type_filter: str, tag_query: Optional[Mapping], session: DBStoreSession
    ):
        """Run ``remove_all`` with ``session``, mapping DBStore errors."""
        try:
            await self._call_handle_or_session(
                session, "remove_all", type_filter, tag_query
            )
        except DBStoreError as err:
            raise StorageError("Error when deleting records") from err

    @staticmethod
    async def _invoke(method, *args, **kwargs):
        """Call ``method`` and resolve its result according to its kind.

        Coroutine functions are awaited. Async generator functions are drained
        into a list, because an async generator object cannot be awaited.
        Anything else is called synchronously and its return value is passed
        through unchanged.
        """
        if inspect.iscoroutinefunction(method):
            return await method(*args, **kwargs)
        if inspect.isasyncgenfunction(method):
            return [item async for item in method(*args, **kwargs)]
        return method(*args, **kwargs)

    @staticmethod
    def _handle_method(session, method_name: str):
        """Return ``session.handle.<method_name>``, or None when unavailable."""
        handle = getattr(session, "handle", None)
        if handle is None or not hasattr(handle, method_name):
            return None
        return getattr(handle, method_name)

    async def _call_handle_or_session(self, session, method_name: str, *args, **kwargs):
        """Call a DB session method handling both sync handle.* and async session.*.

        Resolution order:

        * ``fetch_all`` / ``remove_all``: call ``session.<method>`` first. If it
          raises ``TypeError``, fall back to ``session.handle.<method>``. When no
          handle method exists, the ``TypeError`` is re-raised rather than
          invoking the session method a second time.
        * Any other method: call ``session.handle.<method>`` when it is a plain
          function or coroutine function. Otherwise use ``session.<method>``.
        * Finally, fall back to ``session.handle.<method>`` of any callable kind.

        Coroutine functions are awaited and async generator functions are
        drained into a list (see ``_invoke``).

        Raises:
            AttributeError: If neither the session nor its handle provides a
                callable ``method_name``.
        """
        if method_name in {"fetch_all", "remove_all"}:
            smethod = getattr(session, method_name, None)
            if smethod is not None and callable(smethod):
                try:
                    return await self._invoke(smethod, *args, **kwargs)
                except TypeError:
                    # NOTE(review): presumably tolerates session methods whose
                    # signature rejects some kwargs (e.g. for_update) — confirm.
                    hmethod = self._handle_method(session, method_name)
                    if hmethod is None:
                        raise
                    return await self._invoke(hmethod, *args, **kwargs)
        else:
            hmethod = self._handle_method(session, method_name)
            if (
                hmethod is not None
                and callable(hmethod)
                and not inspect.isasyncgenfunction(hmethod)
            ):
                return await self._invoke(hmethod, *args, **kwargs)

        smethod = getattr(session, method_name, None)
        if smethod is not None and callable(smethod):
            return await self._invoke(smethod, *args, **kwargs)

        hmethod = self._handle_method(session, method_name)
        if hmethod is not None and callable(hmethod):
            return await self._invoke(hmethod, *args, **kwargs)

        raise AttributeError(f"Session does not provide method {method_name}")
class KanonStorageSearch(BaseStorageSearch):
    """Kanon storage search interface.

    Holds a profile and hands out a fresh ``KanonStorageSearchSession`` bound
    to it for every ``search_records`` call.
    """

    def __init__(self, profile: Profile):
        """Remember the profile that subsequent searches will run against."""
        self._profile = profile

    def search_records(
        self,
        type_filter: str,
        tag_query: Optional[Mapping] = None,
        page_size: Optional[int] = None,
        options: Optional[Mapping] = None,
    ) -> "KanonStorageSearchSession":
        """Create a search session for records of ``type_filter``.

        Args:
            type_filter: Record category to search.
            tag_query: Optional tag filter forwarded to the session.
            page_size: Optional page size forwarded to the session.
            options: Optional options mapping forwarded to the session.

        Returns:
            A new, not-yet-opened ``KanonStorageSearchSession``.
        """
        search_session = KanonStorageSearchSession(
            self._profile,
            type_filter,
            tag_query=tag_query,
            page_size=page_size,
            options=options,
        )
        return search_session
class KanonStorageSearchSession(BaseStorageSearchSession):
    """Kanon storage search session.

    Lazily opens a DBStore scan on first use. A background task force-closes
    the scan 30 seconds after it is opened if it is still open (leak guard).
    """

    def __init__(
        self,
        profile,
        type_filter: str,
        tag_query: Mapping,
        page_size: Optional[int] = None,
        options: Optional[Mapping] = None,
    ):
        """Initialize search session with filter parameters.

        Args:
            profile: Profile whose ``opened.db_store`` provides ``scan()`` and
                whose ``name`` selects the DBStore profile.
            type_filter: Record category to search.
            tag_query: Tag filter passed to the scan.
            page_size: Default ``fetch`` batch size (``DEFAULT_PAGE_SIZE`` if falsy).
            options: Accepted for interface compatibility; not used here.
        """
        self.tag_query = tag_query
        self.type_filter = type_filter
        self.page_size = page_size or DEFAULT_PAGE_SIZE
        self._done = False
        self._profile = profile
        self._scan = None
        self._timeout_task = None

    @property
    def opened(self) -> bool:
        """Check if search is opened."""
        return self._scan is not None

    @property
    def handle(self):
        """Get search handle (the underlying scan object, or None)."""
        return self._scan

    @staticmethod
    def _to_record(row) -> StorageRecord:
        """Convert a scan row to a StorageRecord."""
        return StorageRecord(
            type=row.category,
            id=row.name,
            value=row.value,  # DBStore returns a string from Entry.value
            tags=row.tags,
        )

    def __aiter__(self):
        """Return async iterator."""
        return self

    async def __anext__(self):
        """Get next item from search.

        Raises:
            StorageSearchError: If the search is already complete, or the scan
                fails with a DBStoreError.
            StopAsyncIteration: When the scan is exhausted; the session is
                closed first.
        """
        if self._done:
            raise StorageSearchError("Search query is complete")
        await self._open()
        try:
            if hasattr(self._scan, "__anext__"):
                row = await self._scan.__anext__()
            elif inspect.isawaitable(self._scan):
                # Awaitable scan: will raise DBStoreError per test, map to
                # StorageSearchError
                await self._scan
                await self.close()
                raise StopAsyncIteration
            else:
                # Synchronous iterator fallback
                row = next(self._scan)
            LOGGER.debug("Fetched row: category=%s, name=%s", row.category, row.name)
        except DBStoreError as err:
            await self.close()
            raise StorageSearchError(ERR_FETCH_SEARCH_RESULTS) from err
        except StopAsyncIteration:
            await self.close()
            raise
        except StopIteration:
            # A StopIteration escaping a coroutine becomes RuntimeError
            # (PEP 479); translate it into the async-iterator protocol.
            await self.close()
            raise StopAsyncIteration from None
        return self._to_record(row)

    async def fetch(
        self, max_count: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[StorageRecord]:
        """Fetch up to ``max_count`` (default ``page_size``) records.

        ``offset`` only takes effect on the call that opens the scan. An empty
        result closes the session.

        Raises:
            StorageSearchError: If the search is already complete or the scan
                fails with a DBStoreError.
        """
        if self._done:
            raise StorageSearchError("Search query is complete")
        limit = max_count or self.page_size
        await self._open(limit=limit, offset=offset)
        ret = []
        if not hasattr(self._scan, "__anext__") and inspect.isawaitable(self._scan):
            try:
                await self._scan
            except DBStoreError as err:
                await self.close()
                raise StorageSearchError(ERR_FETCH_SEARCH_RESULTS) from err
            # No rows yielded
            await self.close()
            return ret
        # Stop early if the scan was closed meanwhile (e.g. by the timeout task).
        while len(ret) < limit and self._scan is not None:
            try:
                if hasattr(self._scan, "__anext__"):
                    row = await self._scan.__anext__()
                else:
                    row = next(self._scan)
            except DBStoreError as err:
                await self.close()
                raise StorageSearchError(ERR_FETCH_SEARCH_RESULTS) from err
            except (StopAsyncIteration, StopIteration):
                # Catch StopIteration too: a sync iterator's exhaustion would
                # otherwise surface as RuntimeError from this coroutine.
                break
            LOGGER.debug("Fetched row: category=%s, name=%s", row.category, row.name)
            ret.append(self._to_record(row))
        if not ret:
            await self.close()
        return ret

    async def _open(self, offset: Optional[int] = None, limit: Optional[int] = None):
        """Open the DBStore scan once and start the leak-guard timeout task.

        Raises:
            StorageSearchError: If creating the scan raises a DBStoreError.
        """
        if self._scan is not None:
            return
        try:
            LOGGER.debug(
                "Opening scan for type_filter=%s, tag_query=%s, limit=%s, offset=%s",
                self.type_filter,
                self.tag_query,
                limit,
                offset,
            )
            self._scan = self._profile.opened.db_store.scan(
                category=self.type_filter,
                tag_filter=self.tag_query,
                offset=offset,
                limit=limit,
                profile=self._profile.name,
            )
        except DBStoreError as err:
            raise StorageSearchError("Error opening search query") from err
        self._timeout_task = asyncio.create_task(self._timeout_close())

    async def _timeout_close(self):
        """Close the scan after a timeout to prevent leaks."""
        await asyncio.sleep(30)
        if self._scan and not self._done:
            LOGGER.warning("Scan timeout reached, forcing closure")
            await self.close()

    async def close(self):
        """Close search session.

        Cancels the timeout task unless close() is running inside it.
        Cancelling the current task would abort the scan's own close call.
        Closing the scan is best-effort: failures are logged, not raised.
        """
        timeout_task, self._timeout_task = self._timeout_task, None
        if timeout_task is not None and timeout_task is not asyncio.current_task():
            timeout_task.cancel()
        scan = self._scan
        if scan is None:
            return
        # Detach before awaiting so a concurrent close() cannot close it twice.
        self._scan = None
        self._done = True
        try:
            aclose = getattr(scan, "aclose", None)
            if aclose:
                await aclose()
            else:
                close = getattr(scan, "close", None)
                if close:
                    res = close()
                    if inspect.iscoroutine(res):
                        await res
            LOGGER.debug("Closed KanonStorageSearchSession scan")
        except Exception as err:
            LOGGER.debug("Ignoring error while closing search scan: %s", err)

    async def __aexit__(self, exc_type, exc, tb):
        """Exit async context manager, closing the scan and logging any error."""
        await self.close()
        if exc_type:
            LOGGER.error("Exception in KanonStorageSearchSession: %s", exc)
        return False