Source code for irradiapy.utils.mpi

"""Utilities for MPI-related functionalities."""

# pylint: disable=no-name-in-module, broad-except

import sys
import traceback
from functools import wraps
from pathlib import Path

from mpi4py import MPI

from irradiapy.utils.math import repeated_prime_factors


[docs] def broadcast_variables(root: int, comm: MPI.Comm, *variables) -> list: """Broadcasts variables. Parameters ---------- root : int The rank of the process that will broadcast the variables. comm : MPI.Comm The MPI communicator. variables : tuple The variables to broadcast. Returns ------- list The broadcasted variables. """ return [comm.bcast(var, root=root) for var in variables]
[docs] def rm_file(path_file: Path, comm: MPI.Comm) -> None: """Remove a file.""" rank = comm.Get_rank() if rank == 0: path_file.unlink(missing_ok=True) comm.Barrier()
[docs] def cp_file(original: Path, target: Path, comm: MPI.Comm) -> None: """Copy a file from `original` to `target`, overwriting `target` if it exists.""" rank = comm.Get_rank() if rank == 0: print(f"\n\n\nCopying {original} to {target}") target.write_bytes(original.read_bytes()) comm.Barrier()
[docs] def mv_file(original: Path, target: Path, comm: MPI.Comm) -> None: """Move a file from `original` to `target`, overwriting `target` if it exists.""" rank = comm.Get_rank() if rank == 0: target.write_bytes(original.read_bytes()) original.unlink() comm.Barrier()
[docs] def ap_rm_file(original: Path, target: Path, comm: MPI.Comm) -> None: """Append content from `original` to `target` and delete `original`.""" rank = comm.Get_rank() if rank == 0: original_content = original.read_bytes() with target.open("ab") as f: f.write(original_content) original.unlink() comm.Barrier()
[docs] def mpi_safe_method(method): """Decorator that wraps an MPI-using method so any exception prints a traceback with the current rank and then calls MPI.Abort. The method should be a member of a class that has `comm` and `rank` attributes. """ @wraps(method) def wrapper(self, *args, **kwargs): try: return method(self, *args, **kwargs) except Exception: tb = traceback.format_exc() sys.stderr.write(f"Rank {self.rank} raised an exception:\n{tb}\n") sys.stderr.flush() self.comm.Abort(1) return wrapper
[docs] def mpi_subdomains_decomposition(n: int) -> tuple[int, int, int]: """Factor `n` into three integers `nx`, `ny`, `nz` for MPI decomposition into subdomains. `nx`, `ny` and `nz` are such that: - `nx * ny * nz == n` - the maximum of (`nx`, `ny`, `nz`) divided by the minimum is as small as possible. Parameters ---------- n : int The number of processes to decompose into subdomains. Returns ------- tuple[int, int, int] A tuple containing the number of subdomains in the x, y, and z directions. """ # Start with 1,1,n dims = [1, 1, 1] facs = repeated_prime_factors(n) # For each prime, multiply into the currently smallest dimension for p in sorted(facs, reverse=True): # Sort so that dims[0] ≤ dims[1] ≤ dims[2] dims.sort() dims[0] *= p return tuple(sorted(dims))
[docs] class MPIExceptionHandlerMixin: """Provides a common MPI exception handler method when `mpi_safe_method` cannot be used. This is useful for `__init__` and `__post_init__` methods where decorators cannot be applied because `self.comm` and `self.rank` are not yet defined before the function is called. You can use this method after the `comm` and `rank` attributes are initialized in a try/except block. """ def _handle_exception(self) -> None: tb = traceback.format_exc() sys.stderr.write(f"[Rank {self.rank}] Exception:\n{tb}\n") sys.stderr.flush() self.comm.Abort(1)
[docs] class MPITagAllocator: """A class to allocate unique tags for processes.""" _next_tag = 0
[docs] @classmethod def get_tag(cls): """Get a unique tag for the current process. Warning ------- This method should not be called in a statement that is only executed by one rank, such as in `if rank == 0:`. It is designed to be called by all ranks to ensure that all ranks receive the same tag. """ tag = cls._next_tag cls._next_tag += 1 return tag