Source code for irradiapy.io.lammpsreadermpi

"""This module contains the `LAMMPSReaderMPI` class."""

# pylint: disable=no-name-in-module, broad-except

from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from types import TracebackType
from typing import Any, Dict, Generator, TextIO, Tuple, Type

import numpy as np
from mpi4py import MPI
from numpy import typing as npt

from irradiapy.utils.mpi import (
    MPIExceptionHandlerMixin,
    MPITagAllocator,
    mpi_safe_method,
    mpi_subdomains_decomposition,
)


@dataclass
class LAMMPSReaderMPI(MPIExceptionHandlerMixin):
    """A class to read data from a LAMMPS dump file in parallel using MPI.

    Rank 0 opens and parses the file. For every frame the header and the
    column layout are broadcast to all ranks, and the atom lines are split
    into near-equal consecutive chunks, one per rank. Each rank then builds
    its own structured array from its chunk.

    Note
    ----
    Assumed orthogonal simulation box.

    Attributes
    ----------
    file_path : Path
        The path to the LAMMPS dump file.
    encoding : str, optional (default="utf-8")
        The file encoding.
    comm : MPI.Comm, optional (default=mpi4py.MPI.COMM_WORLD)
        The MPI communicator.
    """

    # Columns stored as integers; every other column is stored as float64.
    _INT_COLUMNS = ("id", "type", "element", "size")
    # Scaled-coordinate column -> unscaled column it can be derived from.
    _SCALED_SOURCES = (("xs", "x"), ("ys", "y"), ("zs", "z"))

    file_path: Path
    encoding: str = "utf-8"
    file: TextIO | None = field(default=None, init=False)
    comm: MPI.Comm = field(default_factory=lambda: MPI.COMM_WORLD)
    __rank: int = field(init=False)
    __commsize: int = field(init=False)
    __comm_tag: int = field(default_factory=MPITagAllocator.get_tag, init=False)
    __nx: int = field(init=False)
    __ny: int = field(init=False)
    __nz: int = field(init=False)

    def __post_init__(self):
        self.__rank = self.comm.Get_rank()
        self.__commsize = self.comm.Get_size()
        self.__nx, self.__ny, self.__nz = mpi_subdomains_decomposition(self.__commsize)
        # Map the linear rank onto (ix, iy, iz) with x varying fastest.
        ix = self.__rank % self.__nx
        iy = (self.__rank // self.__nx) % self.__ny
        iz = self.__rank // (self.__nx * self.__ny)
        self._domain_index = (ix, iy, iz)
        self.file = None
        # Only rank 0 ever touches the file; other ranks keep `file = None`.
        if self.__rank == 0:
            try:
                self.file = open(self.file_path, "r", encoding=self.encoding)
            except Exception:
                self._handle_exception()

    def __enter__(self) -> "LAMMPSReaderMPI":
        return self

    def __del__(self) -> None:
        # Must not rely on attributes set late in __post_init__: the object
        # may be only partially initialised when it is collected.
        self.__close_file()

    def __exit__(
        self,
        exc_type: None | type[BaseException] = None,
        exc_value: None | BaseException = None,
        exc_traceback: None | TracebackType = None,
    ) -> bool:
        """Exits the context manager."""
        self.__close_file()
        return False

    def __close_file(self) -> None:
        """Close the underlying file if this rank holds an open handle.

        Plain (undecorated) helper so it is safe to call from ``__del__``.
        The handle is only ever set on rank 0, so no rank check is needed.
        """
        handle = self.file
        if handle is not None and not handle.closed:
            handle.close()

    def __read_layout(
        self, file: TextIO
    ) -> Tuple[list[str], list[Type[int | float]], np.dtype, int]:
        """Read the atoms item line and build the per-atom record layout.

        Parameters
        ----------
        file : TextIO
            Open dump file positioned at the line listing the atom columns.

        Returns
        -------
        tuple
            ``(items, types, dtype, ncols)`` where the first ``ncols``
            entries of ``items`` are the columns present in the file and any
            remaining entries are scaled coordinates to be derived later.
        """
        # The first two tokens of the line are the item label, not columns.
        columns = file.readline().split()[2:]
        types = [
            np.int64 if name in self._INT_COLUMNS else np.float64 for name in columns
        ]
        derived: list[str] = []
        if all(axis in columns for axis in ("x", "y", "z")):
            # Only add the scaled columns the file does not already provide;
            # duplicated field names would make the dtype invalid.
            derived = [
                scaled for scaled, _ in self._SCALED_SOURCES if scaled not in columns
            ]
        items = columns + derived
        types += [np.float64] * len(derived)
        return items, types, np.dtype(list(zip(items, types))), len(columns)

    def __get_dtype(
        self, file: TextIO
    ) -> Tuple[list[str], list[Type[int | float]], np.dtype]:
        """Return ``(items, types, dtype)`` for the atoms item line of ``file``."""
        return self.__read_layout(file)[:3]

    def __process_header(self, file: TextIO) -> Dict[str, Any]:
        """Parse one frame header.

        Parameters
        ----------
        file : TextIO
            Open dump file positioned at the start of a frame (or at EOF).

        Returns
        -------
        dict
            Header values (``time`` if present, ``timestep``, ``natoms``,
            ``boundary`` and the six box bounds), or an empty dict at EOF.
        """
        line = file.readline()
        # Tolerate blank lines between frames / at the end of the file.
        while line and not line.strip():
            line = file.readline()
        if not line:
            return {}

        data: Dict[str, Any] = {}
        if line.strip() == "ITEM: TIME":
            data["time"] = float(file.readline())
            file.readline()  # label line of the timestep item
        # Otherwise `line` itself was the timestep label and is already
        # consumed, so the next line holds the timestep value. No rewind.
        data["timestep"] = int(file.readline())
        file.readline()  # label line preceding the atom count
        data["natoms"] = int(file.readline())
        # Tokens after the first three of this line are the boundary flags.
        data["boundary"] = file.readline().split()[3:]
        bounds = [file.readline().split() for _ in range(3)]
        data["xlo"], data["xhi"] = map(float, bounds[0][:2])
        data["ylo"], data["yhi"] = map(float, bounds[1][:2])
        data["zlo"], data["zhi"] = map(float, bounds[2][:2])
        return data

    @mpi_safe_method
    def __iter__(self) -> Generator[Tuple[Dict[str, Any], npt.NDArray], None, None]:
        """Iterate over the frames of the dump file.

        Yields
        ------
        dict
            Frame header plus ``items``, ``types``, ``dtype``,
            ``subdomain_index``, ``subdomain_bounds``, ``atoms`` (this
            rank's structured array) and ``natoms`` (this rank's atom count).
        """
        is_root = self.__rank == 0
        while True:
            # header broadcast
            data = self.comm.bcast(
                self.__process_header(self.file) if is_root else None, root=0
            )
            if not data:
                break

            # layout broadcast
            items, types, dtype, ncols = self.comm.bcast(
                self.__read_layout(self.file) if is_root else None, root=0
            )
            data.update({"items": items, "types": types, "dtype": dtype})

            # near-equal split: the first `extra` ranks get one more atom
            natoms = data["natoms"]
            base, extra = divmod(natoms, self.__commsize)
            counts = [base + (1 if r < extra else 0) for r in range(self.__commsize)]

            # distribute chunks in file order, sending each one as soon as it
            # is read so rank 0 never holds more than two chunks at a time
            if is_root:
                raw = [self.file.readline().split() for _ in range(counts[0])]
                for dest in range(1, self.__commsize):
                    chunk = [self.file.readline().split() for _ in range(counts[dest])]
                    self.comm.send(chunk, dest=dest, tag=self.__comm_tag)
            else:
                raw = self.comm.recv(source=0, tag=self.__comm_tag)

            # build structured array, one column at a time; only the first
            # `ncols` items exist in the file, the rest are derived below
            arr = np.empty(len(raw), dtype=dtype)
            for col in range(ncols):
                convert = types[col]
                arr[items[col]] = np.fromiter(
                    (convert(row[col]) for row in raw),
                    dtype=convert,
                    count=len(raw),
                )

            # subdomain info: indices and physical bounds
            xlo, xhi = data["xlo"], data["xhi"]
            ylo, yhi = data["ylo"], data["yhi"]
            zlo, zhi = data["zlo"], data["zhi"]
            dx = (xhi - xlo) / self.__nx
            dy = (yhi - ylo) / self.__ny
            dz = (zhi - zlo) / self.__nz
            ix, iy, iz = self._domain_index
            data["subdomain_index"] = (ix, iy, iz)
            data["subdomain_bounds"] = {
                "xlo": xlo + ix * dx,
                "xhi": xlo + (ix + 1) * dx,
                "ylo": ylo + iy * dy,
                "yhi": ylo + (iy + 1) * dy,
                "zlo": zlo + iz * dz,
                "zhi": zlo + (iz + 1) * dz,
            }

            # derive the scaled coordinates that the file did not provide
            extents = {"x": (xlo, xhi), "y": (ylo, yhi), "z": (zlo, zhi)}
            sources = dict(self._SCALED_SOURCES)
            for scaled in items[ncols:]:
                axis = sources[scaled]
                lo, hi = extents[axis]
                arr[scaled] = (arr[axis] - lo) / (hi - lo)

            # attach atoms; `natoms` becomes the local count from here on
            data["atoms"] = arr
            data["natoms"] = len(arr)
            yield data

        self.__close_file()

    @mpi_safe_method
    def close(self) -> None:
        """Closes the file associated with this reader."""
        self.__close_file()