# Source code for acapy_agent.database_manager.databases.sqlite_normalized.connection_pool

"""Connection pool for SQLite databases, with optional SQLCipher encryption.

Provides ``ConnectionPool``, which hands out ``sqlite3`` connections (or
``sqlcipher3`` connections when an encryption key is configured) from a
bounded queue and maintains them with a background keep-alive thread.
"""

import logging
import queue
import sqlite3
import threading
import time
from typing import Optional

try:
    # Use sqlcipher3 binary (SQLite 3.46+)
    import sqlcipher3 as sqlcipher
except ImportError:
    # sqlcipher3 is optional: it is only needed when an encryption key is
    # supplied, and pooled-connection creation checks for None before use.
    sqlcipher = None

from ..errors import DatabaseError, DatabaseErrorCode

# Module logger for pool errors and keep-alive diagnostics.
LOGGER = logging.getLogger(__name__)

# Executed on every SQLCipher connection immediately after ``PRAGMA key`` to
# select cipher compatibility level 4.
PRAGMA_CIPHER_COMPAT = "PRAGMA cipher_compatibility = 4"


class ConnectionPool:
    """Connection pool manager for SQLite databases.

    Up to ``pool_size`` idle connections to ``db_path`` are kept in a
    :class:`queue.Queue`. Callers borrow one with :meth:`get_connection` and
    hand it back with :meth:`return_connection`. A daemon keep-alive thread
    periodically runs a WAL checkpoint and replaces idle connections that fail
    a health check. When ``encryption_key`` is set, connections are opened
    with ``sqlcipher3`` instead of :mod:`sqlite3`.

    After construction, every change to the queue happens while holding
    ``self.lock``. Waiting for a free connection is done *without* the lock,
    so other threads can return connections while a caller is waiting.
    """

    def __init__(
        self,
        db_path: str,
        pool_size: int,
        busy_timeout: float,
        encryption_key: Optional[str] = None,
        journal_mode: str = "WAL",
        locking_mode: str = "NORMAL",
        synchronous: str = "FULL",
        shared_cache: bool = True,
    ):
        """Initialize SQLite connection pool.

        Args:
            db_path: Database path passed to ``connect``.
            pool_size: Number of connections opened up front; also the
                capacity of the idle-connection queue.
            busy_timeout: ``timeout`` passed to ``connect`` for pooled
                connections.
            encryption_key: SQLCipher passphrase, or ``None`` for an
                unencrypted database.
            journal_mode: ``PRAGMA journal_mode`` for unencrypted connections
                (encrypted connections always use WAL).
            locking_mode: ``PRAGMA locking_mode`` for unencrypted connections.
            synchronous: ``PRAGMA synchronous`` for unencrypted connections.
            shared_cache: For unencrypted connections, selects
                ``cache_size = -2000`` instead of ``-1000``.

        Raises:
            DatabaseError: ``CONNECTION_ERROR`` if any initial connection
                cannot be created or the keep-alive thread cannot start.
        """
        self.db_path = db_path
        self.pool_size = pool_size
        self.busy_timeout = busy_timeout
        self.encryption_key = encryption_key
        self.journal_mode = journal_mode
        self.locking_mode = locking_mode
        self.synchronous = synchronous
        self.shared_cache = shared_cache
        self.pool = queue.Queue(maxsize=pool_size)
        self.lock = threading.Lock()
        # id(conn) -> sequential connection number (used in log messages).
        self.connection_ids = {}
        self.connection_count = 0
        self._keep_alive_running = threading.Event()
        self._keep_alive_running.set()
        self.keep_alive_thread = threading.Thread(target=self._keep_alive, daemon=True)
        try:
            for _ in range(pool_size):
                self.pool.put(self._create_connection())
            self.keep_alive_thread.start()
        except Exception as e:
            # Don't leak the connections opened before the failure.
            self._keep_alive_running.clear()
            self._close_idle_connections()
            LOGGER.error("Failed to initialize connection pool: %s", str(e))
            raise DatabaseError(
                code=DatabaseErrorCode.CONNECTION_ERROR,
                message="Failed to initialize connection pool",
                actual_error=str(e),
            ) from e

    def _keep_alive(self):
        """Background loop: checkpoint the WAL and revalidate idle connections.

        Runs every ``SQLITE_KEEPALIVE_INTERVAL`` seconds (default 10) until
        :meth:`close` clears ``_keep_alive_running``.
        """
        # Allow configuring keep-alive interval for tests
        import os

        keep_alive_interval = int(os.environ.get("SQLITE_KEEPALIVE_INTERVAL", "10"))
        while self._keep_alive_running.is_set():
            # Sleep in 1-second increments to be responsive to shutdown
            for _ in range(keep_alive_interval):
                if not self._keep_alive_running.is_set():
                    return
                time.sleep(1)
            with self.lock:
                # close() may have started while we were waiting for the lock.
                if not self._keep_alive_running.is_set():
                    return
                self._perform_checkpoint()
                self._revalidate_idle_connections()

    def _revalidate_idle_connections(self):
        """Health-check every idle connection and top the queue back up.

        Broken connections are closed and replaced. The queue is then refilled
        toward ``pool_size``, even if some connections are currently checked
        out; :meth:`return_connection` closes any surplus this produces.
        Caller must hold ``self.lock``.
        """
        checked = []
        initial_size = self.pool.qsize()
        while True:
            try:
                conn = self.pool.get_nowait()
            except queue.Empty:
                break
            if self._is_connection_healthy(conn):
                checked.append(conn)
                continue
            self._safe_close_and_forget(conn)
            try:
                checked.append(self._recreate_connection())
            except Exception as e:
                LOGGER.error("Failed to recreate connection in keep-alive: %s", str(e))
        if len(checked) < initial_size:
            LOGGER.warning(
                "Lost %d connections during keep-alive",
                initial_size - len(checked),
            )
        while len(checked) < self.pool_size and self._keep_alive_running.is_set():
            try:
                checked.append(self._recreate_connection())
            except Exception as e:
                LOGGER.error("Failed to restore connection in keep-alive: %s", str(e))
                break
        for conn in checked:
            self._put_or_close(conn)

    def _put_or_close(self, conn):
        """Put ``conn`` into the queue without blocking; close it if full."""
        try:
            self.pool.put_nowait(conn)
        except queue.Full:
            self._safe_close_and_forget(conn)

    @staticmethod
    def _key_pragma(key: str) -> str:
        """Build the ``PRAGMA key`` statement for ``key``.

        The passphrase is embedded as an SQL string literal, so single quotes
        are doubled. Without this, a key containing ``'`` produces a syntax
        error.
        """
        return "PRAGMA key = '" + key.replace("'", "''") + "'"

    @staticmethod
    def _close_quietly(conn):
        """Close ``conn`` if it is not None, ignoring errors (best-effort)."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            pass

    def _open_checkpoint_connection(self):
        """Open a short-lived, unpooled connection used only for checkpoints."""
        if not self.encryption_key:
            return sqlite3.connect(self.db_path, check_same_thread=False)
        conn = sqlcipher.connect(self.db_path, check_same_thread=False)
        try:
            conn.execute(self._key_pragma(self.encryption_key))
            # Skip migration for checkpoint connection - not needed
            conn.execute(PRAGMA_CIPHER_COMPAT)
        except Exception:
            self._close_quietly(conn)
            raise
        return conn

    def _perform_checkpoint(self):
        """Run a WAL checkpoint to keep file sizes bounded."""
        checkpoint_conn = None
        try:
            checkpoint_conn = self._open_checkpoint_connection()
            checkpoint_conn.cursor().execute("PRAGMA wal_checkpoint(TRUNCATE)")
        except Exception as e:
            if ":memory:" in self.db_path:
                LOGGER.debug(
                    "Keep-alive WAL checkpoint failed (in-memory db): %s", str(e)
                )
            else:
                LOGGER.error("Keep-alive WAL checkpoint failed: %s", str(e))
        finally:
            self._close_quietly(checkpoint_conn)

    def _close_checkpoint(self):
        """Best-effort final WAL checkpoint run by :meth:`close`."""
        checkpoint_conn = None
        try:
            checkpoint_conn = self._open_checkpoint_connection()
            if self.encryption_key:
                checkpoint_conn.execute("PRAGMA cipher_memory_security = OFF")
            checkpoint_conn.cursor().execute("PRAGMA wal_checkpoint(TRUNCATE)")
        except Exception as e:
            # The pool is shutting down either way; just record why.
            LOGGER.debug("Final WAL checkpoint on close failed: %s", str(e))
        finally:
            self._close_quietly(checkpoint_conn)

    def _is_connection_healthy(self, conn) -> bool:
        """Return True if ``conn`` can run ``SELECT 1`` and an empty BEGIN/ROLLBACK.

        Any exception counts as unhealthy. NOTE(review): a connection handed
        back with a transaction still open will presumably fail the ``BEGIN``
        and therefore be replaced.
        """
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            cursor.execute("BEGIN")
            cursor.execute("ROLLBACK")
            return True
        except Exception:
            return False

    def _recreate_connection(self):
        """Create a replacement connection (see :meth:`_create_connection`)."""
        return self._create_connection()

    def _safe_close_and_forget(self, conn):
        """Close ``conn`` (ignoring errors) and drop it from ``connection_ids``."""
        self._close_quietly(conn)
        self.connection_ids.pop(id(conn), None)

    def _close_idle_connections(self):
        """Remove every idle connection from the queue and close it."""
        while True:
            try:
                conn = self.pool.get_nowait()
            except queue.Empty:
                return
            self._safe_close_and_forget(conn)

    def _create_connection(self):
        """Open, configure and register a new connection.

        Raises:
            DatabaseError: ``CONNECTION_ERROR`` if the connection cannot be
                opened or configured, including a missing ``sqlcipher3`` when
                an encryption key is set. A partially configured connection is
                closed before raising.
        """
        conn = None
        try:
            if self.encryption_key:
                if sqlcipher is None:
                    raise ImportError(
                        "sqlcipher3 is required for encryption but not installed."
                    )
                conn = sqlcipher.connect(
                    self.db_path, timeout=self.busy_timeout, check_same_thread=False
                )
                try:
                    conn.execute(self._key_pragma(self.encryption_key))
                    # Set compatibility first
                    conn.execute(PRAGMA_CIPHER_COMPAT)
                    # Try to set WAL mode first (must be done before any transactions)
                    conn.execute("PRAGMA journal_mode = WAL")
                    conn.execute("PRAGMA foreign_keys = ON;")
                    # Now test if we can read the database
                    conn.cursor().execute("SELECT count(*) FROM sqlite_master")
                except Exception as e:
                    LOGGER.error("SQLCipher initialization failed: %s", str(e))
                    raise
            else:
                conn = sqlite3.connect(
                    self.db_path, timeout=self.busy_timeout, check_same_thread=False
                )
                conn.execute(f"PRAGMA journal_mode = {self.journal_mode}")
                conn.execute(f"PRAGMA locking_mode = {self.locking_mode}")
                conn.execute(f"PRAGMA synchronous = {self.synchronous}")
                conn.execute(
                    "PRAGMA cache_size = -2000"
                    if self.shared_cache
                    else "PRAGMA cache_size = -1000"
                )
                conn.execute("PRAGMA foreign_keys = ON;")
                conn.execute("PRAGMA wal_autocheckpoint = 1000")
            self.connection_ids[id(conn)] = self.connection_count
            self.connection_count += 1
            return conn
        except Exception as e:
            # Don't leak a connection that was opened but failed configuration.
            self._close_quietly(conn)
            LOGGER.error("Failed to create database connection: %s", str(e))
            raise DatabaseError(
                code=DatabaseErrorCode.CONNECTION_ERROR,
                message="Failed to create database connection",
                actual_error=str(e),
            ) from e

    def get_connection(self, timeout: float = 30.0):
        """Get a connection from the pool.

        Polls the queue every 0.1 s until a connection that passes the health
        check is available. Unhealthy connections are closed and replaced.
        The lock is released between polls so :meth:`return_connection` can
        hand connections back while this call waits.

        Args:
            timeout: Seconds to keep polling before giving up.

        Returns:
            A healthy connection; give it back with :meth:`return_connection`.

        Raises:
            DatabaseError: ``CONNECTION_POOL_EXHAUSTED`` if no connection
                became available within ``timeout``; ``CONNECTION_ERROR`` for
                any other failure.
        """
        deadline = time.monotonic() + timeout
        try:
            while time.monotonic() < deadline:
                with self.lock:
                    try:
                        conn = self.pool.get_nowait()
                    except queue.Empty:
                        conn = None
                    if conn is not None:
                        if self._is_connection_healthy(conn):
                            return conn
                        # unhealthy: close and replace, then poll again
                        self._safe_close_and_forget(conn)
                        try:
                            self._put_or_close(self._recreate_connection())
                        except Exception as e:
                            LOGGER.error("Failed to recreate connection: %s", str(e))
                        continue
                # Pool is empty: wait *outside* the lock.
                time.sleep(0.1)
            LOGGER.error("Connection pool exhausted after %d seconds", timeout)
            raise DatabaseError(
                code=DatabaseErrorCode.CONNECTION_POOL_EXHAUSTED,
                message=f"Connection pool exhausted after {timeout} seconds",
            )
        except DatabaseError:
            # Preserve specific codes (e.g. CONNECTION_POOL_EXHAUSTED).
            raise
        except Exception as e:
            LOGGER.error("Failed to retrieve connection from pool: %s", str(e))
            raise DatabaseError(
                code=DatabaseErrorCode.CONNECTION_ERROR,
                message="Failed to retrieve connection from pool",
                actual_error=str(e),
            ) from e

    def return_connection(self, conn):
        """Return a connection to the pool.

        Healthy connections go back into the queue. Broken ones are closed and
        replaced with a fresh connection. If the queue already holds
        ``pool_size`` idle connections (the keep-alive pass refills it even
        while connections are checked out), ``conn`` is closed instead. A
        blocking ``put`` here would hold ``self.lock`` forever and deadlock
        every other pool user.

        Raises:
            DatabaseError: ``CONNECTION_ERROR`` if a replacement connection
                cannot be created.
        """
        with self.lock:
            conn_id = self.connection_ids.get(id(conn), -1)
            try:
                if self.pool.full():
                    self._safe_close_and_forget(conn)
                    LOGGER.debug("Pool full; closed surplus connection ID=%d", conn_id)
                    return
                if self._is_connection_healthy(conn):
                    self.pool.put_nowait(conn)
                else:
                    self._safe_close_and_forget(conn)
                    self.pool.put_nowait(self._recreate_connection())
                LOGGER.debug(
                    "Connection ID=%d returned to pool. Pool size: %d/%d",
                    conn_id,
                    self.pool.qsize(),
                    self.pool_size,
                )
            except Exception:
                self._safe_close_and_forget(conn)
                try:
                    self._put_or_close(self._recreate_connection())
                except Exception as e:
                    LOGGER.error("Failed to recreate connection for pool: %s", str(e))
                    raise DatabaseError(
                        code=DatabaseErrorCode.CONNECTION_ERROR,
                        message="Failed to recreate connection for pool",
                        actual_error=str(e),
                    ) from e

    def drain_all_connections(self):
        """Drain all connections from the pool.

        Removes every idle connection from the queue and returns them in a
        list without closing them. They stay registered in
        ``connection_ids``, and the caller becomes responsible for them.
        """
        connections = []
        with self.lock:
            while True:
                try:
                    connections.append(self.pool.get_nowait())
                except queue.Empty:
                    break
        return connections

    def close(self):
        """Close the connection pool.

        Stops the keep-alive thread, waiting up to ``SQLITE_CLOSE_TIMEOUT``
        seconds (default 15). Then runs a final WAL checkpoint and closes
        every idle connection. Connections that are currently checked out are
        not closed here.
        """
        import os

        # Allow configuring close timeout for tests (default 15s for production)
        close_timeout = float(os.environ.get("SQLITE_CLOSE_TIMEOUT", "15.0"))
        # Stop and join the keep-alive thread *before* taking self.lock. That
        # thread acquires self.lock on every pass, so joining while holding
        # the lock would stall close() for the whole timeout.
        self._keep_alive_running.clear()
        self.keep_alive_thread.join(timeout=close_timeout)
        with self.lock:
            self._close_checkpoint()
            self._close_idle_connections()
            self.connection_ids.clear()