"""This module contains the `LAMMPSReaderMPI` class."""
# pylint: disable=no-name-in-module, broad-except
from dataclasses import dataclass, field
from pathlib import Path
from types import TracebackType
from typing import Any, Generator, TextIO
import numpy as np
from mpi4py import MPI
from irradiapy.utils.mpi import (
MPIExceptionHandlerMixin,
MPITagAllocator,
mpi_safe_method,
mpi_subdomains_decomposition,
)
@dataclass
class LAMMPSReaderMPI(MPIExceptionHandlerMixin):
    """A class to read data from a LAMMPS dump file in parallel using MPI.

    Note
    ----
    Assumed orthogonal simulation box.

    Note
    ----
    Rank 0 reads each timestep one by one, then sends lists of split atom
    lines to all ranks, which build local numpy structured arrays. Atoms are
    split in file order into contiguous, near-equal chunks (ranks with a
    lower index get one extra atom when the count does not divide evenly).
    They are NOT distributed by position, so the atoms a rank receives are
    not guaranteed to lie inside its ``subdomain_bounds``.

    Attributes
    ----------
    file_path : Path
        The path to the LAMMPS dump file.
    encoding : str, optional (default="utf-8")
        The file encoding.
    comm : MPI.Comm, optional (default=mpi4py.MPI.COMM_WORLD)
        The MPI communicator.

    Yields
    ------
    dict[str, Any]
        A dictionary containing the timestep data with keys:
        'time' (only if the snapshot has an ``ITEM: TIME`` entry),
        'timestep', 'boundary', 'xlo', 'xhi', 'ylo', 'yhi', 'zlo', 'zhi',
        'items' (column names), 'types' (column numpy scalar types),
        'dtype' (structured dtype built from those), 'subdomain_index'
        ((ix, iy, iz) of this rank), 'subdomain_bounds' (dict with the
        six bound keys for this rank's box slice), 'subdomain_atoms' (this
        rank's atoms as a numpy structured array) and 'subdomain_natoms'.
    """

    file_path: Path
    encoding: str = "utf-8"
    comm: MPI.Comm = field(default_factory=lambda: MPI.COMM_WORLD)
    # Only rank 0 ever holds an open file; it stays None on other ranks.
    __file: TextIO | None = field(default=None, init=False)
    __rank: int = field(init=False)
    __commsize: int = field(init=False)
    __comm_tag: int = field(default_factory=MPITagAllocator.get_tag, init=False)
    __nx: int = field(init=False)
    __ny: int = field(init=False)
    __nz: int = field(init=False)
    # Atom count of the current snapshot. Parsed on rank 0 and broadcast to
    # every rank in __iter__. The default keeps the dataclass-generated
    # __repr__/__eq__ working before the first snapshot is read.
    __natoms: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        """Set up rank info, the subdomain grid and (on rank 0) the file."""
        self.__rank = self.comm.Get_rank()
        self.__commsize = self.comm.Get_size()
        self.__nx, self.__ny, self.__nz = mpi_subdomains_decomposition(self.__commsize)
        # Map the linear rank to grid indices, x varying fastest.
        ix = self.__rank % self.__nx
        iy = (self.__rank // self.__nx) % self.__ny
        iz = self.__rank // (self.__nx * self.__ny)
        self._domain_index = (ix, iy, iz)
        self.__file = None
        if self.__rank == 0:
            try:
                self.__file = open(self.file_path, "r", encoding=self.encoding)
            except Exception:
                self._handle_exception()

    def __enter__(self) -> "LAMMPSReaderMPI":
        return self

    def __del__(self) -> None:
        self.__close_file()

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        exc_traceback: TracebackType | None = None,
    ) -> bool:
        """Exits the context manager, closing the file on rank 0."""
        self.__close_file()
        return False

    def __close_file(self) -> None:
        """Close the dump file if this rank holds an open one (rank 0 only)."""
        if self.__file is not None and not self.__file.closed:
            self.__file.close()

    def __get_dtype(
        self, file: TextIO
    ) -> tuple[list[str], list[type[int | float]], np.dtype]:
        """Parse the ``ITEM: ATOMS <col> ...`` line into names, types and dtype.

        Columns named id/type/element/size become int64; all others float64.
        NOTE(review): LAMMPS writes ``element`` as a name string, so int64
        parsing will fail for dumps containing it - confirm intended usage.
        """
        items = file.readline().split()[2:]  # drop "ITEM:" and "ATOMS"
        types = [
            np.int64 if it in ("id", "type", "element", "size") else np.float64
            for it in items
        ]
        return items, types, np.dtype(list(zip(items, types)))

    def __process_header(self, file: TextIO) -> dict[str, Any]:
        """Parse one snapshot header up to (not including) the ATOMS line.

        Sets ``self.__natoms`` as a side effect.

        Returns
        -------
        dict[str, Any]
            The header fields, or an empty dict at end of file.
        """
        data: dict[str, Any] = {}
        line = file.readline()
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        while line and not line.strip():
            line = file.readline()
        if not line:
            return {}
        if line.strip() == "ITEM: TIME":
            data["time"] = float(file.readline())
            file.readline()  # consume "ITEM: TIMESTEP"
        # Otherwise ``line`` was "ITEM: TIMESTEP" and is already consumed, so
        # both branches are now positioned at the timestep value.
        data["timestep"] = int(file.readline())
        file.readline()  # "ITEM: NUMBER OF ATOMS"
        self.__natoms = int(file.readline())
        # "ITEM: BOX BOUNDS pp pp pp" -> ["pp", "pp", "pp"]
        data["boundary"] = file.readline().split()[3:]
        bounds = [file.readline().split() for _ in range(3)]
        data["xlo"], data["xhi"] = map(float, bounds[0][:2])
        data["ylo"], data["yhi"] = map(float, bounds[1][:2])
        data["zlo"], data["zhi"] = map(float, bounds[2][:2])
        return data

    def __bcast_header(self) -> dict[str, Any]:
        """Parse the next header on rank 0; broadcast it and the atom count."""
        payload = None
        if self.__rank == 0:
            header = self.__process_header(self.__file)
            payload = (header, self.__natoms)
        header, self.__natoms = self.comm.bcast(payload, root=0)
        return header

    def __distribute_rows(self) -> list[list[str]]:
        """Send contiguous chunks of split atom lines from rank 0 to all ranks.

        Returns this rank's chunk; earlier ranks get the remainder atoms.
        """
        if self.__rank != 0:
            return self.comm.recv(source=0, tag=self.__comm_tag)
        nbase, nextra = divmod(self.__natoms, self.__commsize)
        counts = [nbase + (1 if r < nextra else 0) for r in range(self.__commsize)]
        chunks = [
            [self.__file.readline().split() for _ in range(count)] for count in counts
        ]
        for dest in range(1, self.__commsize):
            self.comm.send(chunks[dest], dest=dest, tag=self.__comm_tag)
        return chunks[0]

    @staticmethod
    def __rows_to_array(
        raw: list[list[str]],
        items: list[str],
        types: list[type[int | float]],
        dtype: np.dtype,
    ) -> np.ndarray:
        """Convert split text rows into a structured array, column by column."""
        arr = np.empty(len(raw), dtype=dtype)
        if not raw:
            return arr
        ncols = len(items)
        # 2-D string table; extra trailing fields (if any) are ignored.
        table = np.array([row[:ncols] for row in raw], dtype=str)
        for j, (key, typ) in enumerate(zip(items, types)):
            arr[key] = table[:, j].astype(typ)
        return arr

    def __subdomain_bounds(self, data: dict[str, Any]) -> dict[str, float]:
        """Return this rank's slice of the box on the nx x ny x nz grid."""
        ix, iy, iz = self._domain_index
        bounds: dict[str, float] = {}
        for axis, idx, ncells in (
            ("x", ix, self.__nx),
            ("y", iy, self.__ny),
            ("z", iz, self.__nz),
        ):
            lo, hi = data[f"{axis}lo"], data[f"{axis}hi"]
            width = (hi - lo) / ncells
            bounds[f"{axis}lo"] = lo + idx * width
            bounds[f"{axis}hi"] = lo + (idx + 1) * width
        return bounds

    @mpi_safe_method
    def __iter__(self) -> Generator[dict[str, Any], None, None]:
        """Yield one dict per snapshot on every rank (see class docstring)."""
        while True:
            # All ranks execute the same sequence of collectives per snapshot.
            data = self.__bcast_header()
            if not data:
                break
            items, types, dtype = self.comm.bcast(
                self.__get_dtype(self.__file) if self.__rank == 0 else None,
                root=0,
            )
            data.update({"items": items, "types": types, "dtype": dtype})
            raw = self.__distribute_rows()
            arr = self.__rows_to_array(raw, items, types, dtype)
            data["subdomain_index"] = self._domain_index
            data["subdomain_bounds"] = self.__subdomain_bounds(data)
            data["subdomain_atoms"] = arr
            data["subdomain_natoms"] = len(arr)
            yield data
        self.__close_file()

    @mpi_safe_method
    def close(self) -> None:
        """Closes the file associated with this reader."""
        self.__close_file()