# Source code for irradiapy.database

"""This module contains the `Database` class."""

import sqlite3
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import Any, Generator

import numpy as np
import numpy.typing as npt


[docs] @dataclass class Database(sqlite3.Connection): """A SQLite database with utility methods. Not intended to be used directly, instead use it to inherit other database classes. Parameters ---------- path : Path Path to the SQLite database file. """ path: Path def __post_init__(self) -> None: super().__init__(self.path) def __exit__( self, exc_type: type[BaseException] | None = None, exc_value: BaseException | None = None, exc_traceback: TracebackType | None = None, ) -> bool: """Exit the runtime context related to this object.""" self.close() return False
[docs] def optimize(self) -> None: """Optimize the SQLite database. This method performs two operations to optimize the database: 1. Executes the "PRAGMA optimize" command to analyze and optimize the database. 2. Executes the "VACUUM" command to rebuild the database file, repacking it into a minimal amount of disk space. """ cur = self.cursor() cur.execute("PRAGMA optimize") cur.execute("VACUUM") cur.close()
[docs] def table_exists(self, table_name: str) -> bool: """Checks if the given table exists in the database. Parameters ---------- table_name : str Table's name to check. Returns ------- bool Whether the table already exists or not. """ cur = self.cursor() cur.execute( ( "SELECT COUNT(*) FROM sqlite_master WHERE type='table'" f"AND name='{table_name}'" ) ) result = cur.fetchone()[0] cur.close() return bool(result)
[docs] def table_has_column(self, table_name: str, column_name: str) -> bool: """Checks if the given table has the specified column. Parameters ---------- table_name : str Table's name to check. column_name : str Column's name to check. Returns ------- bool Whether the column exists in the table or not. """ cur = self.cursor() cur.execute(f"PRAGMA table_info({table_name})") columns = [info[1] for info in cur.fetchall()] cur.close() return column_name in columns
[docs] def read( self, table: str, what: str = "*", conditions: str = "", ) -> Generator[tuple[Any, ...], None, None]: """Reads table data from the database as a generator. Parameters ---------- table : str Table to read from. what : str Columns to select. conditions : str Conditions to filter data. Yields ------ Generator[tuple[Any, ...], None, None] Data from the database. """ cur = self.cursor() try: cur.execute(f"SELECT {what} FROM {table} {conditions}") yield from cur finally: cur.close()
[docs] def read_chunk( self, table: str, what: str = "*", condition: str = "", chunksize: int = 10_000, ) -> Generator[tuple[Any, ...], None, None]: """Reads table data from the database as a generator in chunks. It might be faster than ``read`` for huge tables. Parameters ---------- table : str Table to read from. what : str Columns to select. condition : str Conditions to filter data. chunksize : int, optional (default=10_000) Number of rows to read per chunk. Yields ------ Generator[tuple[Any, ...], None, None] Data from the database. """ cur = self.cursor() try: cur.execute(f"SELECT {what} FROM {table} {condition}") while True: rows = cur.fetchmany(chunksize) if not rows: break yield from rows finally: cur.close()
[docs] def read_numpy( self, table: str, what: str, conditions: str = "", ) -> npt.NDArray: """Reads table data from the database as a NumPy structured array. Parameters ---------- table : str Table to read from. what : str Columns to select. conditions : str Conditions to filter data. Returns ------- npt.NDArray Data from the database as a NumPy structured array. """ cur = self.cursor() cur.execute(f"PRAGMA table_info({table})") columns_info = cur.fetchall() cur.close() columns_dtype = {} for column in columns_info: column_name = column[1] column_type = column[2].upper() if "INT" in column_type: columns_dtype[column_name] = "i8" elif ( "REAL" in column_type or "FLOAT" in column_type or "DOUBLE" in column_type ): columns_dtype[column_name] = "f8" else: # Default to string type columns_dtype[column_name] = "S256" if what.strip() == "*": column_names = [column[1] for column in columns_info] else: column_names = [] for name in what.split(","): name = name.strip() parts = name.split() # column AS alias if len(parts) >= 3 and parts[-2].lower() == "as": name = parts[-1] # column alias (implicit alias) elif len(parts) >= 2: name = parts[-1] # column else: name = parts[0] column_names.append(name) dtype_fields = [ (column_name, columns_dtype.get(column_name, "S256")) for column_name in column_names ] dtype = np.dtype(dtype_fields) array = np.fromiter( self.read(table=table, what=what, conditions=conditions), dtype=dtype, ) return array