Source code for irradiapy.io.bzip2lammpswritermpi

"""This module contains the `BZIP2LAMMPSWriterMPI` class."""

# pylint: disable=no-name-in-module, broad-except

import bz2
from dataclasses import dataclass, field
from pathlib import Path
from types import TracebackType
from typing import Any, BinaryIO, TextIO

import numpy.typing as npt
from mpi4py import MPI

from irradiapy import config
from irradiapy.utils.mpi import (
    MPIExceptionHandlerMixin,
    MPITagAllocator,
    mpi_safe_method,
)


@dataclass
class BZIP2LAMMPSWriterMPI(MPIExceptionHandlerMixin):
    """A class to write data like a LAMMPS dump bzip2 compressed file in parallel
    using MPI.

    Note
    ----
    Assumed orthogonal simulation box.

    Note
    ----
    All ranks compute their local subdomain of atoms, then rank 0 collects and
    writes the data. For every snapshot, the header and each rank's chunk of
    atoms are compressed independently with ``bz2.compress``. The output is
    therefore a concatenation of bzip2 streams: one header stream, then one
    stream per rank, in rank order.

    Parameters
    ----------
    file_path : Path
        Output .bz2 path.
    mode : str
        File mode for the *container* file opened on rank 0 (default: 'wb').
        Must be a binary mode, since compressed bytes are written to it.
    encoding : str, optional (default=irradiapy.config.ENCODING)
        Encoding used to turn the dump text into bytes before compression.
    comm : MPI.Comm, optional (default=mpi4py.MPI.COMM_WORLD)
        The MPI communicator.
    compresslevel : int, optional (default=9)
        Compression level for bzip2.
    int_format : str, optional (default=irradiapy.config.INT_FORMAT)
        The format for integers (signed and unsigned).
    float_format : str, optional (default=irradiapy.config.FLOAT_FORMAT)
        The format for floats.
    excluded_items : list[str], optional (default=irradiapy.config.EXCLUDED_ITEMS)
        Atom fields to exclude from output. The default is a *copy* of the
        config list, so mutating it does not affect the config or other writers.
    """

    file_path: Path
    mode: str = "wb"
    encoding: str = field(default_factory=lambda: config.ENCODING)
    comm: MPI.Comm = field(default_factory=lambda: MPI.COMM_WORLD)
    compresslevel: int = 9
    int_format: str = field(default_factory=lambda: config.INT_FORMAT)
    float_format: str = field(default_factory=lambda: config.FLOAT_FORMAT)
    # Copied so each writer owns its list instead of aliasing the config global.
    excluded_items: list[str] = field(
        default_factory=lambda: list(config.EXCLUDED_ITEMS)
    )
    # Binary handle of the output file; only ever opened on rank 0.
    __file: BinaryIO | None = field(default=None, init=False)
    __rank: int = field(init=False)
    __size: int = field(init=False)
    # Base tag for point-to-point messages; ``tag + 1`` carries the
    # "go ahead" handshake in write().
    __comm_tag: int = field(default_factory=MPITagAllocator.get_tag, init=False)

    def __post_init__(self) -> None:
        """Cache rank/size and open the output file on rank 0 only."""
        self.__rank = self.comm.Get_rank()
        self.__size = self.comm.Get_size()
        self.__file = None
        if self.__rank == 0:
            try:
                self.__file = open(self.file_path, self.mode)
            except Exception:
                self._handle_exception()

    def __enter__(self) -> "BZIP2LAMMPSWriterMPI":
        return self

    def __del__(self) -> None:
        # Best-effort local close. This does not go through ``close()`` (which
        # is wrapped by ``mpi_safe_method``), because finalizers run at
        # arbitrary times on each rank.
        try:
            handle = self.__file
        except AttributeError:
            # __init__ failed before the dataclass assigned the field.
            return
        # Only rank 0 ever holds a non-None handle.
        if handle is not None and not handle.closed:
            handle.close()

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        exc_traceback: TracebackType | None = None,
    ) -> bool:
        if self.__rank == 0 and self.__file is not None and not self.__file.closed:
            self.__file.close()
        # Never suppress exceptions raised inside the ``with`` block.
        return False

    def __atoms_rank_to_string(
        self, atoms_rank: npt.NDArray, field_names: list[str], formatters: list[str]
    ) -> str:
        """Converts the atoms_rank array to a formatted string.

        Parameters
        ----------
        atoms_rank : npt.NDArray
            The atoms_rank array to be converted.
        field_names : list[str]
            The names of the fields in the structured array.
        formatters : list[str]
            The format strings for each field (one ``%`` specifier each).

        Returns
        -------
        str
            One line per atom, fields separated by single spaces, lines joined
            by ``"\\n"`` without a trailing newline.
        """
        # One %-format per row instead of one per field. Each value is wrapped
        # in the argument tuple, so it is formatted exactly as ``fmt % value``.
        row_format = " ".join(formatters)
        return "\n".join(
            row_format % tuple(atom[name] for name in field_names)
            for atom in atoms_rank
        )

    def __formatter_for(self, dtype: Any) -> str:
        """Return the %-format used for a column of the given numpy dtype."""
        if dtype.kind in ("i", "u"):
            return self.int_format
        if dtype.kind == "f":
            return self.float_format
        return "%s"

    def __header_string(
        self, data: dict[str, Any], natoms: int, field_names: list[str]
    ) -> str:
        """Build the LAMMPS dump header text for one snapshot."""
        ffmt = self.float_format
        lines = []
        if data.get("time") is not None:
            lines.append(f"ITEM: TIME\n{ffmt % data['time']}\n")
        lines.append(f"ITEM: TIMESTEP\n{self.int_format % data['timestep']}\n")
        lines.append(f"ITEM: NUMBER OF ATOMS\n{self.int_format % natoms}\n")
        lines.append(f"ITEM: BOX BOUNDS {' '.join(data['boundary'])}\n")
        for axis in "xyz":
            lines.append(f"{ffmt % data[axis + 'lo']} {ffmt % data[axis + 'hi']}\n")
        lines.append(f"ITEM: ATOMS {' '.join(field_names)}\n")
        return "".join(lines)

    @mpi_safe_method
    def close(self) -> None:
        """Closes the file associated with this writer (rank 0 only)."""
        if self.__rank == 0 and self.__file is not None and not self.__file.closed:
            self.__file.close()

    @mpi_safe_method
    def write(self, data: dict[str, Any]) -> None:
        """Write the data to the file.

        Every rank of ``comm`` must call this for each snapshot, since it
        performs an ``allreduce``, barriers and point-to-point messages.

        Parameters
        ----------
        data : dict[str, Any]
            A dictionary containing the data to be written. Required keys:
            "subdomain_atoms" (this rank's atoms as a structured array, read
            on every rank) and "timestep", "boundary", "xlo", "xhi", "ylo",
            "yhi", "zlo", "zhi" (read on rank 0 only). Optional key: "time".
        """
        atoms = data["subdomain_atoms"]
        excluded = set(self.excluded_items)
        field_names = [name for name in atoms.dtype.names if name not in excluded]
        formatters = [self.__formatter_for(atoms.dtype[name]) for name in field_names]
        natoms = self.comm.allreduce(len(atoms), op=MPI.SUM)

        # Header: written by rank 0 as its own bzip2 stream.
        if self.__rank == 0:
            header = self.__header_string(data, natoms, field_names)
            self.__file.write(
                bz2.compress(
                    header.encode(self.encoding), compresslevel=self.compresslevel
                )
            )
        self.comm.Barrier()

        # Every rank compresses its own chunk locally.
        lines_chunk = self.__atoms_rank_to_string(atoms, field_names, formatters)
        if lines_chunk and not lines_chunk.endswith("\n"):
            lines_chunk += "\n"
        comp = bz2.compress(
            lines_chunk.encode(self.encoding), compresslevel=self.compresslevel
        )

        if self.__rank == 0:
            self.__file.write(comp)
            # Pull chunks one rank at a time, in rank order. The empty message
            # on tag+1 tells a rank that rank 0 is ready for its payload.
            for sender in range(1, self.__size):
                self.comm.send(None, dest=sender, tag=self.__comm_tag + 1)
                payload = self.comm.recv(source=sender, tag=self.__comm_tag)
                self.__file.write(payload)
        else:
            self.comm.recv(source=0, tag=self.__comm_tag + 1)
            self.comm.send(comp, dest=0, tag=self.__comm_tag)
        self.comm.Barrier()