Source code for acapy_agent.database_manager.databases.sqlite_normalized.database

"""SQLite normalized database implementation."""

import asyncio
import logging
import sqlite3
import threading
import time
from typing import Generator, Optional

try:
    # Try new sqlcipher3 first (SQLite 3.46+)
    import sqlcipher3 as sqlcipher
except ImportError:
    sqlcipher = None
from ...category_registry import get_release
from ...db_types import Entry
from ...interfaces import AbstractDatabaseStore
from ...wql_normalized.query import query_from_str
from ...wql_normalized.tags import query_to_tagquery
from ..errors import DatabaseError, DatabaseErrorCode
from .connection_pool import ConnectionPool

LOGGER = logging.getLogger(__name__)


[docs] def enc_name(name: str) -> str: """Encode name for database storage. Args: name: Name to encode Returns: Encoded name """ return name
[docs] def enc_value(value: str) -> str: """Encode value for database storage. Args: value: Value to encode Returns: Encoded value """ return value
[docs] class SqliteDatabase(AbstractDatabaseStore): """SQLite database implementation for normalized storage.""" def __init__( self, pool: ConnectionPool, default_profile: str, path: str, release_number: str = "release_0", ): """Initialize SQLite database.""" self.lock = threading.RLock() self.pool = pool self.default_profile = default_profile self.path = path self.release_number = release_number # The self.release_number comes # from the schema_release_number stored in the config table self.active_sessions = [] self.session_creation_times = {} self.max_sessions = int(pool.pool_size * 0.75) # need load test self._monitoring_task: Optional[asyncio.Task] = None try: self.default_profile_id = self._get_profile_id(default_profile) except Exception as e: LOGGER.error( "Failed to initialize default profile ID for '%s': %s", default_profile, str(e), ) raise DatabaseError( code=DatabaseErrorCode.PROFILE_NOT_FOUND, message=( f"Failed to initialize default profile ID for '{default_profile}'" ), actual_error=str(e), )
[docs] async def start_monitoring(self): """Start monitoring active database sessions.""" if self._monitoring_task is None or self._monitoring_task.done(): self._monitoring_task = asyncio.create_task(self._monitor_active_sessions())
async def _monitor_active_sessions(self): while True: await asyncio.sleep(5) # check every 5 secs with self.lock: if self.active_sessions: current_time = time.time() for session in self.active_sessions[:]: session_id = id(session) creation_time = self.session_creation_times.get(session_id, 0) age_seconds = current_time - creation_time if age_seconds > 5: # close sessions older than 5secs try: await session.close() except Exception: pass def _get_profile_id(self, profile_name: str) -> int: with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() cursor.execute("SELECT id FROM profiles WHERE name = ?", (profile_name,)) row = cursor.fetchone() if row: return row[0] LOGGER.error("Profile '%s' not found", profile_name) raise DatabaseError( code=DatabaseErrorCode.PROFILE_NOT_FOUND, message=f"Profile '{profile_name}' not found", ) except Exception as e: LOGGER.error( "Failed to retrieve profile ID for '%s': %s", profile_name, str(e) ) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message=f"Failed to retrieve profile ID for '{profile_name}'", actual_error=str(e), ) finally: self.pool.return_connection(conn)
[docs] async def create_profile(self, name: str = None) -> str: """Create a new profile in the database. Args: name: Profile name to create Returns: str: The created profile name """ name = name or "new_profile" def _create(): with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT OR IGNORE INTO profiles (name, profile_key) " "VALUES (?, NULL)", (name,), ) if cursor.rowcount == 0: LOGGER.error("Profile '%s' already exists", name) raise DatabaseError( code=DatabaseErrorCode.PROFILE_ALREADY_EXISTS, message=f"Profile '{name}' already exists", ) if not hasattr(self, "is_txn") or not self.is_txn: conn.commit() return name except Exception as e: if not hasattr(self, "is_txn") or not self.is_txn: conn.rollback() LOGGER.error("Failed to create profile '%s': %s", name, str(e)) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message=f"Failed to create profile '{name}'", actual_error=str(e), ) finally: self.pool.return_connection(conn) return await asyncio.to_thread(_create)
[docs] async def get_profile_name(self) -> str: """Get the default profile name. Returns: str: Default profile name """ return self.default_profile
[docs] async def remove_profile(self, name: str) -> bool: """Remove a profile from the database. Args: name: Profile name to remove Returns: bool: True if removed successfully """ def _remove(): with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM profiles WHERE name = ?", (name,)) result = cursor.rowcount > 0 if not hasattr(self, "is_txn") or not self.is_txn: conn.commit() return result except Exception as e: if not hasattr(self, "is_txn") or not self.is_txn: conn.rollback() LOGGER.error("Failed to remove profile '%s': %s", name, str(e)) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message=f"Failed to remove profile '{name}'", actual_error=str(e), ) finally: self.pool.return_connection(conn) return await asyncio.to_thread(_remove)
[docs] async def rekey(self, key_method: str = None, pass_key: str = None): """Rekey the database with new encryption. Args: key_method: Key method to use pass_key: Password key for encryption """ def _rekey(): with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() cursor.execute("PRAGMA cipher_version;") if not cursor.fetchone()[0]: LOGGER.error("Database is not encrypted") raise DatabaseError( code=DatabaseErrorCode.DATABASE_NOT_ENCRYPTED, message="Database is not encrypted", ) cursor.execute(f"PRAGMA rekey = '{pass_key}'") if not hasattr(self, "is_txn") or not self.is_txn: conn.commit() self.pool.encryption_key = pass_key except Exception as e: if not hasattr(self, "is_txn") or not self.is_txn: conn.rollback() LOGGER.error("Failed to rekey database: %s", str(e)) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message="Failed to rekey database", actual_error=str(e), ) finally: self.pool.return_connection(conn) await asyncio.to_thread(_rekey)
[docs] def scan( self, profile: Optional[str], category: str, tag_filter: str | dict = None, offset: int = None, limit: int = None, order_by: Optional[str] = None, descending: bool = False, ) -> Generator[Entry, None, None]: """Scan entries in the database with filtering and pagination. Args: profile: Profile name to scan category: Category to scan tag_filter: Tag filter criteria offset: Offset for pagination limit: Limit for pagination order_by: Column to order by descending: Whether to sort descending Yields: Entry: Database entries matching criteria """ handlers, _, _ = get_release(self.release_number, "sqlite") handler = handlers.get(category, handlers["default"]) profile_id = self._get_profile_id(profile or self.default_profile) tag_query = None if tag_filter: wql_query = query_from_str(tag_filter) tag_query = query_to_tagquery(wql_query) with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() for entry in handler.scan( cursor, profile_id, category, tag_query, offset, limit, order_by, descending, ): yield entry except DatabaseError as e: LOGGER.error("Failed to execute scan query: %s", str(e)) raise except Exception as e: LOGGER.error("Failed to execute scan query: %s", str(e)) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message="Failed to execute scan query", actual_error=str(e), ) finally: self.pool.return_connection(conn)
[docs] def scan_keyset( self, profile: Optional[str], category: str, tag_filter: str | dict = None, last_id: Optional[int] = None, limit: int = None, order_by: Optional[str] = None, descending: bool = False, ) -> Generator[Entry, None, None]: """Scan entries using keyset pagination. Args: profile: Profile name to scan category: Category to scan tag_filter: Tag filter criteria last_id: Last ID for cursor-based pagination limit: Limit for pagination order_by: Column to order by descending: Whether to sort descending Yields: Entry: Database entries """ handlers, _, _ = get_release(self.release_number, "sqlite") handler = handlers.get(category, handlers["default"]) profile_id = self._get_profile_id(profile or self.default_profile) tag_query = None if tag_filter: wql_query = query_from_str(tag_filter) tag_query = query_to_tagquery(wql_query) with self.lock: conn = self.pool.get_connection() try: cursor = conn.cursor() for entry in handler.scan_keyset( cursor, profile_id, category, tag_query, last_id, limit, order_by, descending, ): yield entry except DatabaseError as e: LOGGER.error("Failed to execute scan_keyset query: %s", str(e)) raise except Exception as e: LOGGER.error("Failed to execute scan_keyset query: %s", str(e)) raise DatabaseError( code=DatabaseErrorCode.QUERY_ERROR, message="Failed to execute scan_keyset query", actual_error=str(e), ) finally: self.pool.return_connection(conn)
[docs] def session(self, profile: str = None, release_number: str = "release_0"): """Create a context manager for database session. Args: profile: Profile name to use release_number: Release number for schema versioning Returns: SqliteSession: Database session context manager """ from .session import SqliteSession with self.lock: if len(self.active_sessions) >= self.max_sessions: LOGGER.error( "Maximum number of active sessions reached: %d", self.max_sessions ) raise DatabaseError( code=DatabaseErrorCode.CONNECTION_POOL_EXHAUSTED, message="Maximum number of active sessions reached", ) sess = SqliteSession( self, profile or self.default_profile, False, self.release_number ) with self.lock: self.active_sessions.append(sess) self.session_creation_times[id(sess)] = time.time() LOGGER.debug( "[session] Active sessions: %d, session_id=%s", len(self.active_sessions), id(sess), ) return sess
[docs] def transaction(self, profile: str = None, release_number: str = "release_0"): """Create a transaction context manager. Args: profile: Profile name to use release_number: Release number for schema versioning Returns: SqliteSession: Database transaction context manager """ from .session import SqliteSession with self.lock: if len(self.active_sessions) >= self.max_sessions: LOGGER.error( "Maximum number of active sessions reached: %d", self.max_sessions ) raise DatabaseError( code=DatabaseErrorCode.CONNECTION_POOL_EXHAUSTED, message="Maximum number of active sessions reached", ) sess = SqliteSession( self, profile or self.default_profile, True, self.release_number ) with self.lock: self.active_sessions.append(sess) self.session_creation_times[id(sess)] = time.time() LOGGER.debug( "[session] Active sessions: %d, session_id=%s", len(self.active_sessions), id(sess), ) return sess
[docs] def close(self, remove: bool = False): """Close the database and optionally remove the file. Args: remove: Whether to remove the database file """ try: # Cancel background monitoring task if running if self._monitoring_task and not self._monitoring_task.done(): self._monitoring_task.cancel() try: asyncio.get_event_loop().run_until_complete(self._monitoring_task) except Exception: pass finally: self._monitoring_task = None if self.pool: checkpoint_conn = None try: checkpoint_conn = ( sqlite3.connect(self.path, check_same_thread=False) if not self.pool.encryption_key else sqlcipher.connect(self.path, check_same_thread=False) ) if self.pool.encryption_key: checkpoint_conn.execute( f"PRAGMA key = '{self.pool.encryption_key}'" ) checkpoint_conn.execute("PRAGMA cipher_migrate") checkpoint_conn.execute("PRAGMA cipher_compatibility = 4") cursor = checkpoint_conn.cursor() cursor.execute("PRAGMA wal_checkpoint(TRUNCATE)") except Exception as e: LOGGER.error("WAL checkpoint failed: %s", str(e)) finally: if checkpoint_conn: checkpoint_conn.close() try: self.pool.close() except Exception as e: LOGGER.error("Failed to close connection pool: %s", str(e)) raise DatabaseError( code=DatabaseErrorCode.CONNECTION_ERROR, message="Failed to close connection pool", actual_error=str(e), ) except Exception as e: LOGGER.error("Failed to close database: %s", str(e)) raise DatabaseError( code=DatabaseErrorCode.CONNECTION_ERROR, message="Failed to close database", actual_error=str(e), )