"""Database store module for managing different database backends."""
import asyncio
import importlib
import inspect
import json
# anext is a builtin in Python 3.10+
import logging
import threading
from collections.abc import AsyncGenerator, AsyncIterator
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Sequence
from .db_types import Entry, EntryList
from .error import DBStoreError, DBStoreErrorCode
from .interfaces import AbstractDatabaseSession, AbstractDatabaseStore, DatabaseBackend
# Relative module path (resolved against this package) whose
# ``register_backends()`` fills the registry the first time a store
# entry point finds it empty.
_BACKEND_REGISTRATION_IMPORT = ".databases.backends.backend_registration"

# Module-wide logger used by every class below.
LOGGER = logging.getLogger(__name__)

# URI scheme (text before the first ":") -> backend that handles it.
# Lazy population of this mapping is serialized by the lock below.
_backend_registry: dict[str, DatabaseBackend] = {}
_registry_lock = threading.Lock()
def register_backend(db_type: str, backend: DatabaseBackend):
    """Register ``backend`` as the handler for URIs whose scheme is ``db_type``.

    A later registration for the same ``db_type`` silently replaces the
    earlier one.

    Args:
        db_type: URI scheme, i.e. the part of a store URI before the first
            ``:``, that should select this backend.
        backend: Backend implementation to use for that scheme.
    """
    # NOTE: ``_registry_lock`` is deliberately not acquired here. It is a
    # plain (non-reentrant) threading.Lock, and the DBStore entry points
    # already hold it while running ``register_backends()``, which presumably
    # calls this function; taking it again would deadlock that path.
    LOGGER.debug("Registering backend for db_type=%s", db_type)
    _backend_registry[db_type] = backend
class Scan(AsyncIterator):
    """Async iterator over the entries produced by the backend's ``scan``.

    The backend's ``scan`` may be:

    * an async generator function (its result is iterated directly),
    * a coroutine function (its awaited result is iterated), or
    * a plain function returning a sync iterable. In that case the iterable
      is created and advanced on a private single-worker thread pool, so the
      event loop is never blocked.

    The backend is not called until the first ``__anext__``.
    """

    # Private end-of-iteration marker for the sync path. Unlike ``None`` it
    # can never be confused with a value produced by a backend.
    _EXHAUSTED = object()

    def __init__(
        self,
        store: "DBStore",
        profile: Optional[str],
        category: str | bytes,
        tag_filter: str | dict | None = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
        descending: bool = False,
    ):
        """Record the scan parameters; no backend call is made yet.

        Args:
            store: Store whose backend (``store._db``) will be scanned.
            profile: Profile to scan; passed through to the backend.
            category: Entry category. ``bytes`` is decoded to ``str``, the
                same normalisation ``ScanKeyset`` applies.
            tag_filter: Optional tag filter, passed through.
            offset: Optional row offset, passed through.
            limit: Optional maximum number of rows, passed through.
            order_by: Optional ordering column, passed through.
            descending: Ordering direction, passed through.
        """
        self._store = store
        self._profile = profile
        self._category = (
            category.decode() if isinstance(category, bytes) else category
        )
        self._tag_filter = tag_filter
        self._offset = offset
        self._limit = limit
        self._order_by = order_by
        self._descending = descending
        self._generator = None
        self._done = False
        # One worker thread: every call into a sync iterator is serialized on
        # the same thread.
        self._executor = ThreadPoolExecutor(max_workers=1)
        scan_fn = self._store._db.scan
        self._is_async = inspect.iscoroutinefunction(
            scan_fn
        ) or inspect.isasyncgenfunction(scan_fn)

    def _scan_args(self) -> tuple:
        """Positional arguments for the backend's ``scan``, in its order."""
        return (
            self._profile,
            self._category,
            self._tag_filter,
            self._offset,
            self._limit,
            self._order_by,
            self._descending,
        )

    async def _open_source(self):
        """Call the backend's ``scan`` and return the iterator to consume."""
        scan_fn = self._store._db.scan
        args = self._scan_args()
        if self._is_async:
            source = scan_fn(*args)
            # A coroutine-function backend hands back an awaitable that
            # resolves to the async iterator. Async generators are not
            # awaitable, so they pass through untouched.
            if inspect.isawaitable(source):
                source = await source
            return source
        loop = asyncio.get_running_loop()
        # iter() accepts both iterators and plain iterables (e.g. lists).
        return await loop.run_in_executor(
            self._executor, lambda: iter(scan_fn(*args))
        )

    def _next_sync(self):
        """Advance the sync iterator (runs on the executor thread)."""
        return next(self._generator, self._EXHAUSTED)

    def _finish(self) -> None:
        """Mark the scan exhausted and release the worker thread."""
        self._done = True
        self._executor.shutdown(wait=False)

    async def __anext__(self) -> Entry:
        """Return the next entry, or raise StopAsyncIteration when done."""
        if self._done:
            raise StopAsyncIteration
        if self._generator is None:
            self._generator = await self._open_source()
        if self._is_async:
            try:
                return await anext(self._generator)  # noqa: F821
            except StopAsyncIteration:
                # Normal end of the scan; not an error.
                self._finish()
                raise
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(self._executor, self._next_sync)
        if result is self._EXHAUSTED:
            self._finish()
            raise StopAsyncIteration
        return result

    async def fetch_all(self) -> Sequence[Entry]:
        """Drain the remaining entries of the scan into a list."""
        return [row async for row in self]

    def __del__(self) -> None:
        """Release the worker thread if it is still alive."""
        # getattr: __init__ may have failed before the executor was created.
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)
class ScanKeyset(AsyncIterator):
    """Async iterator over the backend's keyset-paginated ``scan_keyset``.

    The backend's ``scan_keyset`` may be:

    * an async generator function,
    * a coroutine function (its awaited result is iterated), or
    * a plain function returning a sync iterable. In that case the iterable
      is created and advanced on a private single-worker thread pool.

    The backend is not called until the first ``__anext__``.
    """

    # Private end-of-iteration marker for the sync path. Unlike ``None`` it
    # can never be confused with a value produced by a backend.
    _EXHAUSTED = object()

    def __init__(
        self,
        store: "DBStore",
        profile: Optional[str],
        category: str | bytes,
        tag_filter: str | dict | None = None,
        last_id: Optional[int] = None,
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
        descending: bool = False,
    ):
        """Record the keyset-scan parameters; no backend call is made yet.

        Args:
            store: Store whose backend (``store._db``) will be scanned.
            profile: Profile to scan; passed through to the backend.
            category: Entry category; non-``str`` values are ``.decode()``d.
            tag_filter: Optional tag filter, passed through.
            last_id: Keyset cursor passed through to the backend; presumably
                the id of the last row of the previous page — confirm with
                the backend.
            limit: Optional maximum number of rows, passed through.
            order_by: Optional ordering column, passed through.
            descending: Ordering direction, passed through.
        """
        LOGGER.debug(
            "ScanKeyset initialized with store=%s, "
            "profile=%s, category=%s, "
            "tag_filter=%s, last_id=%s, "
            "limit=%s, order_by=%s, "
            "descending=%s",
            store,
            profile,
            category,
            tag_filter,
            last_id,
            limit,
            order_by,
            descending,
        )
        self._store = store
        self._profile = profile
        self._category = category if isinstance(category, str) else category.decode()
        self._tag_filter = tag_filter
        self._last_id = last_id
        self._limit = limit
        self._order_by = order_by
        self._descending = descending
        self._generator = None
        self._done = False
        # One worker thread: every call into a sync iterator is serialized on
        # the same thread.
        self._executor = ThreadPoolExecutor(max_workers=1)
        scan_fn = self._store._db.scan_keyset
        self._is_async = inspect.iscoroutinefunction(
            scan_fn
        ) or inspect.isasyncgenfunction(scan_fn)

    def _scan_args(self) -> tuple:
        """Positional arguments for the backend's ``scan_keyset``."""
        return (
            self._profile,
            self._category,
            self._tag_filter,
            self._last_id,
            self._limit,
            self._order_by,
            self._descending,
        )

    async def _open_source(self):
        """Call the backend's ``scan_keyset`` and return the iterator."""
        scan_fn = self._store._db.scan_keyset
        args = self._scan_args()
        if self._is_async:
            source = scan_fn(*args)
            # Coroutine-function backends return an awaitable resolving to
            # the async iterator; async generators are not awaitable.
            if inspect.isawaitable(source):
                source = await source
            return source
        loop = asyncio.get_running_loop()
        # The sync backend returns a regular iterator/iterable, not an
        # AsyncGenerator; iter() normalises iterables to iterators.
        return await loop.run_in_executor(
            self._executor, lambda: iter(scan_fn(*args))
        )

    def _next_sync(self):
        """Advance the sync iterator (runs on the executor thread)."""
        return next(self._generator, self._EXHAUSTED)

    def _finish(self) -> None:
        """Mark the scan exhausted and release the worker thread."""
        self._done = True
        self._executor.shutdown(wait=False)

    async def __anext__(self) -> Entry:
        """Return the next entry, or raise StopAsyncIteration when done."""
        if self._done:
            raise StopAsyncIteration
        if self._generator is None:
            self._generator = await self._open_source()
        if self._is_async:
            try:
                return await anext(self._generator)  # noqa: F821
            except StopAsyncIteration:
                # Normal end of the scan; not an error.
                self._finish()
                raise
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(self._executor, self._next_sync)
        if result is self._EXHAUSTED:
            self._finish()
            raise StopAsyncIteration
        return result

    def __del__(self) -> None:
        """Release the worker thread if it is still alive."""
        # getattr: __init__ may have failed before the executor was created.
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)

    async def fetch_all(self) -> Sequence[Entry]:
        """Drain the remaining entries of the keyset scan into a list."""
        return [row async for row in self]
class DBStore:
    """Backend-agnostic database store.

    Build instances with ``provision`` (create a store) or ``open`` (open an
    existing one). Both select the backend from the module registry using the
    URI scheme, i.e. the text before the first ``:``. Backend methods may be
    sync or async. Where this class dispatches on that, coroutine functions
    are awaited and plain functions are run with ``asyncio.to_thread``.
    """

    def __init__(
        self, db: AbstractDatabaseStore, uri: str, release_number: str = "release_0"
    ):
        """Wrap an already provisioned/opened backend store.

        Args:
            db: Backend store that every operation delegates to.
            uri: URI the store was provisioned or opened with.
            release_number: Schema release recorded for this store.
        """
        LOGGER.debug("Store initialized (release_number=%s)", release_number)
        self._db = db
        self._uri = uri
        self._release_number = release_number
        # Opener backing ``async with store:``; created lazily by __aenter__.
        self._opener: Optional[DBOpenSession] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_backend(uri: str) -> DatabaseBackend:
        """Return the registered backend for ``uri``'s scheme.

        The registry is populated from the backend-registration module on
        first use, under ``_registry_lock``.

        Raises:
            DBStoreError: ``BACKEND`` if no backend handles the scheme.
        """
        with _registry_lock:
            if not _backend_registry:  # Register backends if not already done
                registration = importlib.import_module(
                    _BACKEND_REGISTRATION_IMPORT, package=__package__
                )
                registration.register_backends()
        db_type = uri.split(":")[0]
        backend = _backend_registry.get(db_type)
        if not backend:
            raise DBStoreError(
                DBStoreErrorCode.BACKEND, f"Unsupported database type: {db_type}"
            )
        return backend

    @staticmethod
    async def _dispatch(func, /, *args, **kwargs):
        """Call a backend method that may be sync or async.

        Coroutine functions are awaited directly. Anything else runs in a
        worker thread via ``asyncio.to_thread`` so it cannot block the loop.
        """
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return await asyncio.to_thread(func, *args, **kwargs)

    async def _guarded(self, op: str, call):
        """Await ``call()``, logging and translating non-cancellation errors.

        Failures are logged as ``"<op> error: <message>"`` and re-raised as
        ``self._db.translate_error(e)``. Cancellation propagates unchanged.
        """
        try:
            return await call()
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("%s error: %s", op, str(e))
            raise self._db.translate_error(e)

    # ------------------------------------------------------------------
    # Construction / lifecycle
    # ------------------------------------------------------------------

    @classmethod
    def generate_raw_key(cls, seed: str | bytes | None = None) -> str:
        """Generate a raw store key via ``bindings.generate_raw_key``.

        Args:
            seed: Optional seed forwarded to the binding.
        """
        LOGGER.debug("generate_raw_key called (seed_provided=%s)", bool(seed))
        from . import bindings

        return bindings.generate_raw_key(seed)

    @property
    def handle(self):
        """Process-local identifier of this store object (its ``id()``)."""
        return id(self)

    @property
    def uri(self) -> str:
        """URI the store was provisioned or opened with."""
        return self._uri

    @property
    def release_number(self) -> str:
        """Schema release recorded for this store."""
        return self._release_number

    @classmethod
    async def provision(
        cls,
        uri: str,
        key_method: Optional[str] = None,
        pass_key: Optional[str] = None,
        *,
        profile: Optional[str] = None,
        recreate: bool = False,
        release_number: str = "release_0",
        schema_config: Optional[str] = None,
        config: Optional[dict] = None,
    ) -> "DBStore":
        """Provision a new store with the given release and schema.

        Raises:
            DBStoreError: ``BACKEND`` for an unsupported URI scheme; otherwise
                whatever ``backend.translate_error`` returns for a failure.
        """
        LOGGER.debug(
            "provision called (recreate=%s, release_number=%s)",
            recreate,
            release_number,
        )
        backend = cls._resolve_backend(uri)
        try:
            db = await cls._dispatch(
                backend.provision,
                uri,
                key_method,
                pass_key,
                profile,
                recreate,
                release_number,
                schema_config,
                config=config,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Only the exception type is logged, not its message.
            LOGGER.error("provision error: %s", type(e).__name__)
            raise backend.translate_error(e)
        return cls(db, uri, release_number)

    @classmethod
    async def open(
        cls,
        uri: str,
        key_method: Optional[str] = None,
        pass_key: Optional[str] = None,
        *,
        profile: Optional[str] = None,
        schema_migration: Optional[bool] = None,
        target_schema_release_number: Optional[str] = None,
        config: Optional[dict] = None,
    ) -> "DBStore":
        """Open an existing store, optionally migrating its schema.

        The returned store takes its release number from the opened backend
        (``db.release_number``).

        Raises:
            DBStoreError: ``BACKEND`` for an unsupported URI scheme; otherwise
                whatever ``backend.translate_error`` returns for a failure.
        """
        LOGGER.debug(
            "open called (schema_migration=%s, target_schema_release_number=%s)",
            schema_migration,
            target_schema_release_number,
        )
        backend = cls._resolve_backend(uri)
        try:
            db = await cls._dispatch(
                backend.open,
                uri,
                key_method,
                pass_key,
                profile,
                schema_migration,
                target_schema_release_number,
                config=config,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("open error: %s", type(e).__name__)
            raise backend.translate_error(e)
        return cls(db, uri, db.release_number)

    @classmethod
    async def remove(
        cls, uri: str, release_number: str = "release_0", config: Optional[dict] = None
    ) -> bool:
        """Remove the store at ``uri``; returns the backend's result.

        ``release_number`` is only logged; it is not passed to the backend.
        """
        LOGGER.debug("remove called (release_number=%s)", release_number)
        backend = cls._resolve_backend(uri)
        try:
            return await cls._dispatch(backend.remove, uri, config=config)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("remove error: %s", type(e).__name__)
            raise backend.translate_error(e)

    async def initialize(self) -> None:
        """Run the backend's ``initialize``."""
        LOGGER.debug("initialize called")
        try:
            await self._dispatch(self._db.initialize)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("initialize error: %s", type(e).__name__)
            raise self._db.translate_error(e)

    # ------------------------------------------------------------------
    # Profiles / keys (backend methods are awaited directly here)
    # ------------------------------------------------------------------

    async def create_profile(self, name: Optional[str] = None) -> str:
        """Create a profile and return the backend's result (a name)."""
        LOGGER.debug("create_profile called with name=%s", name)
        return await self._guarded(
            "create_profile", lambda: self._db.create_profile(name)
        )

    async def get_profile_name(self) -> str:
        """Return the backend's current profile name."""
        LOGGER.debug("get_profile_name called")
        return await self._guarded(
            "get_profile_name", lambda: self._db.get_profile_name()
        )

    async def remove_profile(self, name: str) -> bool:
        """Remove profile ``name``; returns the backend's result."""
        LOGGER.debug("remove_profile called with name=%s", name)
        return await self._guarded(
            "remove_profile", lambda: self._db.remove_profile(name)
        )

    async def rekey(
        self, key_method: Optional[str] = None, pass_key: Optional[str] = None
    ) -> None:
        """Re-key the store. ``pass_key`` is never written to the log."""
        LOGGER.debug("rekey called with key_method=%s, pass_key=***", key_method)
        await self._guarded("rekey", lambda: self._db.rekey(key_method, pass_key))

    # ------------------------------------------------------------------
    # Scans and sessions
    # ------------------------------------------------------------------

    def scan(
        self,
        category: str,
        tag_filter: str | dict | None = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        profile: Optional[str] = None,
        order_by: Optional[str] = None,
        descending: bool = False,
    ) -> Scan:
        """Return a lazy ``Scan`` over entries matching the criteria."""
        LOGGER.debug(
            "scan called with category=%s, tag_filter=%s, "
            "offset=%s, "
            "limit=%s, profile=%s, order_by=%s, "
            "descending=%s",
            category,
            tag_filter,
            offset,
            limit,
            profile,
            order_by,
            descending,
        )
        return Scan(
            self, profile, category, tag_filter, offset, limit, order_by, descending
        )

    def scan_keyset(
        self,
        category: str,
        tag_filter: str | dict | None = None,
        last_id: Optional[int] = None,
        limit: Optional[int] = None,
        profile: Optional[str] = None,
        order_by: Optional[str] = None,
        descending: bool = False,
    ) -> ScanKeyset:
        """Return a lazy keyset-paginated ``ScanKeyset`` over entries."""
        LOGGER.debug(
            "scan_keyset called with category=%s, "
            "tag_filter=%s, last_id=%s, "
            "limit=%s, profile=%s, order_by=%s, "
            "descending=%s",
            category,
            tag_filter,
            last_id,
            limit,
            profile,
            order_by,
            descending,
        )
        return ScanKeyset(
            self, profile, category, tag_filter, last_id, limit, order_by, descending
        )

    def session(self, profile: Optional[str] = None) -> "DBOpenSession":
        """Return an opener for a non-transactional session."""
        LOGGER.debug("session called with profile=%s", profile)
        return DBOpenSession(self._db, profile, False, self._release_number)

    def transaction(self, profile: Optional[str] = None) -> "DBOpenSession":
        """Return an opener for a transactional session."""
        LOGGER.debug("transaction called with profile=%s", profile)
        return DBOpenSession(self._db, profile, True, self._release_number)

    async def close(self, *, remove: bool = False) -> bool:
        """Close the backend store and drop this object's reference to it.

        Args:
            remove: Forwarded to the backend's ``close``. Presumably it asks
                the backend to also delete the store; confirm per backend.

        Raises:
            DBStoreError: ``UNEXPECTED``, wrapping any backend failure.

        NOTE(review): annotated ``-> bool`` but no value is returned (always
        ``None``); callers should not rely on the return value.
        """
        LOGGER.debug("close called with remove=%s", remove)
        try:
            if self._db:
                await self._dispatch(self._db.close, remove=remove)
                self._db = None
            LOGGER.debug("close completed")
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("close failed: %s", str(e))
            raise DBStoreError(DBStoreErrorCode.UNEXPECTED, str(e)) from e

    async def __aenter__(self) -> "DBStoreSession":
        """Open a default (profile ``None``, non-transactional) session."""
        LOGGER.debug("__aenter__ called")
        if not self._opener:
            self._opener = DBOpenSession(self._db, None, False, self._release_number)
        return await self._opener.__aenter__()

    async def __aexit__(self, exc_type, exc, tb):
        """Close the session opened by ``__aenter__``."""
        LOGGER.debug(
            "__aexit__ called with exc_type=%s, exc=%s, tb=%s", exc_type, exc, tb
        )
        opener, self._opener = self._opener, None
        return await opener.__aexit__(exc_type, exc, tb)

    def __repr__(self) -> str:
        """Return a short representation including the handle."""
        return f"<Store(handle={self.handle})>"
class DBStoreSession:
    """Open session (or transaction) over a backend database session.

    Every operation delegates to the wrapped backend session. Failures other
    than cancellation are logged and re-raised through the backend session's
    ``translate_error``.
    """

    def __init__(self, db_session: AbstractDatabaseSession, is_txn: bool):
        """Wrap ``db_session``; ``is_txn`` enables ``commit``/``rollback``."""
        LOGGER.debug(
            "Session initialized with db_session=%s, is_txn=%s", db_session, is_txn
        )
        self._db_session = db_session
        self._is_txn = is_txn

    @property
    def is_transaction(self) -> bool:
        """Check if the session is a transaction."""
        return self._is_txn

    @property
    def handle(self):
        """Get a unique identifier for the session."""
        return id(self)

    async def _guarded(self, op: str, call):
        """Await ``call()``, logging and translating non-cancellation errors.

        ``call`` is a zero-argument callable so that the backend attribute
        lookup and any argument preparation also happen inside the ``try``.
        Failures are logged as ``"<op> error: <message>"``.
        """
        try:
            return await call()
        except asyncio.CancelledError:
            raise
        except Exception as e:
            LOGGER.error("%s error: %s", op, str(e))
            raise self._db_session.translate_error(e)

    async def count(self, category: str, tag_filter: str | dict | None = None) -> int:
        """Count entries in ``category`` matching ``tag_filter``."""
        LOGGER.debug("count called with category=%s, tag_filter=%s", category, tag_filter)
        return await self._guarded(
            "count", lambda: self._db_session.count(category, tag_filter)
        )

    async def fetch(
        self, category: str, name: str, *, for_update: bool = False
    ) -> Optional[Entry]:
        """Fetch a single entry by category and name (``None`` if absent).

        The backend is called with ``tag_filter=None``.
        """
        LOGGER.debug(
            "fetch called with category=%s, name=%s, for_update=%s",
            category,
            name,
            for_update,
        )
        return await self._guarded(
            "fetch",
            lambda: self._db_session.fetch(
                category, name, tag_filter=None, for_update=for_update
            ),
        )

    async def fetch_all(
        self,
        category: str,
        tag_filter: str | dict | None = None,
        limit: Optional[int] = None,
        for_update: bool = False,
        order_by: Optional[str] = None,
        descending: bool = False,
    ) -> EntryList:
        """Fetch all matching entries, wrapped in an ``EntryList``."""
        LOGGER.debug(
            "fetch_all called with category=%s, "
            "tag_filter=%s, limit=%s, "
            "for_update=%s, order_by=%s, "
            "descending=%s",
            category,
            tag_filter,
            limit,
            for_update,
            order_by,
            descending,
        )

        async def _fetch() -> EntryList:
            entries = await self._db_session.fetch_all(
                category, tag_filter, limit, for_update, order_by, descending
            )
            return EntryList(entries)

        return await self._guarded("fetch_all", _fetch)

    async def insert(
        self,
        category: str,
        name: str,
        value: str | bytes | None = None,
        tags: Optional[dict] = None,
        expiry_ms: Optional[int] = None,
        value_json=None,
    ) -> None:
        """Insert a new entry.

        If ``value`` is ``None`` and ``value_json`` is given, the stored value
        is ``json.dumps(value_json)``.
        """
        # Entry values may hold secrets, so they are deliberately not logged.
        LOGGER.debug(
            "insert called with category=%s, name=%s, tags=%s, expiry_ms=%s "
            "(value redacted)",
            category,
            name,
            tags,
            expiry_ms,
        )

        async def _insert() -> None:
            payload = value
            if payload is None and value_json is not None:
                payload = json.dumps(value_json)
            await self._db_session.insert(category, name, payload, tags, expiry_ms)

        await self._guarded("insert", _insert)

    async def replace(
        self,
        category: str,
        name: str,
        value: str | bytes | None = None,
        tags: Optional[dict] = None,
        expiry_ms: Optional[int] = None,
        value_json=None,
    ) -> None:
        """Replace an existing entry.

        If ``value`` is ``None`` and ``value_json`` is given, the stored value
        is ``json.dumps(value_json)``.
        """
        # Entry values may hold secrets, so they are deliberately not logged.
        LOGGER.debug(
            "replace called with category=%s, name=%s, tags=%s, expiry_ms=%s "
            "(value redacted)",
            category,
            name,
            tags,
            expiry_ms,
        )

        async def _replace() -> None:
            payload = value
            if payload is None and value_json is not None:
                payload = json.dumps(value_json)
            await self._db_session.replace(category, name, payload, tags, expiry_ms)

        await self._guarded("replace", _replace)

    async def remove(self, category: str, name: str) -> None:
        """Remove the entry identified by ``category`` and ``name``."""
        LOGGER.debug("remove called with category=%s, name=%s", category, name)
        await self._guarded("remove", lambda: self._db_session.remove(category, name))

    async def remove_all(
        self, category: str, tag_filter: str | dict | None = None
    ) -> int:
        """Remove all matching entries; returns the backend's result (a count)."""
        LOGGER.debug(
            "remove_all called with category=%s, tag_filter=%s", category, tag_filter
        )
        return await self._guarded(
            "remove_all", lambda: self._db_session.remove_all(category, tag_filter)
        )

    async def commit(self) -> None:
        """Commit the transaction.

        Raises:
            DBStoreError: ``WRAPPER`` if this session is not a transaction.
        """
        LOGGER.debug("commit called")
        if not self._is_txn:
            raise DBStoreError(DBStoreErrorCode.WRAPPER, "Session is not a transaction")
        await self._guarded("commit", lambda: self._db_session.commit())

    async def rollback(self) -> None:
        """Roll back the transaction.

        Raises:
            DBStoreError: ``WRAPPER`` if this session is not a transaction.
        """
        LOGGER.debug("rollback called")
        if not self._is_txn:
            raise DBStoreError(DBStoreErrorCode.WRAPPER, "Session is not a transaction")
        await self._guarded("rollback", lambda: self._db_session.rollback())

    async def close(self) -> None:
        """Close the underlying backend session."""
        LOGGER.debug("close called")
        await self._guarded("close", lambda: self._db_session.close())

    def __repr__(self) -> str:
        """Return a short representation including handle and txn flag."""
        return f"<Session(handle={self.handle}, is_transaction={self._is_txn})>"
class DBOpenSession:
    """Pending session: await it, or use it with ``async with``, to open one.

    Opening asks the backend for a session (``db.transaction`` when
    ``is_txn`` else ``db.session``), enters it, and wraps it in a
    ``DBStoreSession``. On a clean ``async with`` exit a transaction is
    committed. The session is then closed in every case.
    """

    def __init__(
        self,
        db: AbstractDatabaseStore,
        profile: Optional[str],
        is_txn: bool,
        release_number: str,
    ):
        """Record what to open; no backend call is made yet."""
        LOGGER.debug(
            "OpenSession initialized with db=%s, profile=%s, "
            "is_txn=%s, release_number=%s",
            db,
            profile,
            is_txn,
            release_number,
        )
        self._db = db
        self._profile = profile
        self._is_txn = is_txn
        self._release_number = release_number
        # Backend session created by _open(); declared here so it always exists.
        self._db_session: Optional[AbstractDatabaseSession] = None
        self._session: Optional[DBStoreSession] = None

    @property
    def is_transaction(self) -> bool:
        """Whether this opener produces a transactional session."""
        return self._is_txn

    async def _open(self) -> DBStoreSession:
        """Create, enter and wrap a backend session.

        Raises:
            DBStoreError: ``WRAPPER`` if a session is already open.
        """
        LOGGER.debug("_open called")
        if self._session is not None:
            raise DBStoreError(DBStoreErrorCode.WRAPPER, "Session already opened")
        factory = self._db.transaction if self._is_txn else self._db.session
        # The backend factory may be sync or async.
        if inspect.iscoroutinefunction(factory):
            self._db_session = await factory(self._profile)
        else:
            self._db_session = factory(self._profile)
        await self._db_session.__aenter__()
        self._session = DBStoreSession(self._db_session, self._is_txn)
        return self._session

    def __await__(self) -> DBStoreSession:
        """Allow ``session = await opener``."""
        return self._open().__await__()

    async def __aenter__(self) -> DBStoreSession:
        """Open the session for an ``async with`` block."""
        LOGGER.debug("__aenter__ called")
        return await self._open()

    async def __aexit__(self, exc_type, exc, tb):
        """Commit a transaction on a clean exit, then always close.

        Exceptions from the ``async with`` body are not suppressed.
        """
        LOGGER.debug(
            "__aexit__ called with exc_type=%s, exc=%s, tb=%s", exc_type, exc, tb
        )
        session, self._session = self._session, None
        if session is None:
            # Nothing was opened (or it was already closed); nothing to release.
            return None
        try:
            if self._is_txn and exc_type is None:
                await session.commit()
        finally:
            # Release the backend session even if commit() raised.
            await session.close()
        return None