"""
Simple parallelisation of atpass() using multiprocessing.
"""
from functools import partial
import multiprocessing
# noinspection PyProtectedMember
from ..lattice.utils import get_uint32_index
from ..lattice import AtWarning, Element, DConstant, random
from ..lattice import Refpts, End
from warnings import warn
from .atpass import reset_rng, atpass as _atpass
from .track import fortran_align
from typing import Iterable, Optional, List
import numpy as np
__all__ = ['patpass']
_imax = np.iinfo(int).max
_globring: Optional[List[Element]] = None
def format_results(results, r_in, losses):
lin, lout = zip(*results)
# Update r_in with values at the end of tracking
np.concatenate(lin, out=r_in, axis=1)
if losses:
lout, ldic = zip(*lout)
keys = ldic[0].keys()
dicout = dict(((k, np.hstack([li[k] for li in ldic])) for k in keys))
return np.concatenate(lout, axis=1), dicout
else:
return np.concatenate(lout, axis=1)
def _atpass_fork(seed, rank, rin, **kwargs):
"""Single forked job"""
reset_rng(rank, seed=seed)
result = _atpass(_globring, rin, **kwargs)
return rin, result
def _atpass_spawn(ring, seed, rank, rin, **kwargs):
"""Single spawned job"""
reset_rng(rank, seed=seed)
result = _atpass(ring, rin, **kwargs)
return rin, result
def _pass(ring, r_in, pool_size, start_method, **kwargs):
ctx = multiprocessing.get_context(start_method)
# Split input in as many slices as processes
args = enumerate(np.array_split(r_in, pool_size, axis=1))
# Generate a new starting point for C RNGs
seed = random.common.integers(0, high=_imax, dtype=int)
global _globring
_globring = ring
if ctx.get_start_method() == 'fork':
passfunc = partial(_atpass_fork, seed, **kwargs)
else:
passfunc = partial(_atpass_spawn, ring, seed, **kwargs)
# Start the parallel jobs
with ctx.Pool(pool_size) as pool:
results = pool.starmap(passfunc, args)
_globring = None
# Gather the results
losses = kwargs.pop('losses', False)
return format_results(results, r_in, losses)
[docs]@fortran_align
def patpass(lattice: Iterable[Element], r_in, nturns: int = 1,
refpts: Refpts = End, pool_size: int = None,
start_method: str = None, **kwargs):
"""
Simple parallel implementation of :py:func:`.lattice_pass`.
If more than one particle is supplied, use multiprocessing. For a
single particle or if the lattice contains :py:class:`.Collective`
elements, :py:func:`.atpass` is used.
:py:func:`patpass` tracks particles through each element of a lattice
calling the element-specific tracking function specified in the Element's
*PassMethod* field.
Parameters:
lattice: list of elements
r_in: (6, N) array: input coordinates of N particles.
*r_in* is modified in-place and reports the coordinates at
the end of the element. For the best efficiency, *r_in*
should be given as F_CONTIGUOUS numpy array.
nturns: number of turns to be tracked
refpts: Selects the location of coordinates output.
See ":ref:`Selecting elements in a lattice <refpts>`"
pool_size: number of processes. If None,
``min(npart,nproc)`` is used
start_method: python multiprocessing start method.
:py:obj:`None` uses the python default that is considered safe.
Available values: ``'fork'``, ``'spawn'``, ``'forkserver'``.
Default for linux is ``'fork'``, default for macOS and Windows is
``'spawn'``. ``'fork'`` may be used on macOS to speed up the
calculation or to solve Runtime Errors, however it is considered
unsafe.
Keyword arguments:
keep_lattice (bool): Use elements persisted from a previous
call. If :py:obj:`True`, assume that the lattice has not changed
since the previous call.
keep_counter (bool): Keep the turn number from the previous
call.
turn (int): Starting turn number. Ignored if
*keep_counter* is :py:obj:`True`. The turn number is necessary to
compute the absolute path length used in RFCavityPass.
losses (bool): Boolean to activate loss maps output
omp_num_threads (int): Number of OpenMP threads
(default: automatic)
The following keyword arguments overload the Lattice values
Keyword arguments:
particle (Particle): circulating particle.
Default: *lattice.particle* if existing,
otherwise *Particle('relativistic')*
energy (float): lattice energy. Default 0.
If *energy* is not available, relativistic tracking if forced,
*rest_energy* is ignored.
Returns:
r_out: (6, N, R, T) array containing output coordinates of N particles
at R reference points for T turns.
loss_map: If *losses* is :py:obj:`True`: dictionary with the
following key:
============== ===================================================
**islost** (npart,) bool array indicating lost particles
**turn** (npart,) int array indicating the turn at
which the particle is lost
**element** ((npart,) int array indicating the element at
which the particle is lost
**coord** (6, npart) float array giving the coordinates at
which the particle is lost (zero for surviving
particles)
============== ===================================================
.. note::
* For multiparticle tracking with large number of turn the size of
*r_out* may increase excessively. To avoid memory issues
:pycode:`lattice_pass(lattice, r_in, refpts=[])` can be used.
An empty list is returned and the tracking results of the last turn
are stored in *r_in*.
* By default, :py:func:`patpass` will use all the available CPUs.
To change the number of cores used in ALL functions using
:py:func:`patpass` (:py:mod:`~at.acceptance.acceptance` module for
example) it is possible to set ``at.DConstant.patpass_poolsize``
to the desired value.
"""
def collective(rg) -> bool:
"""True if any element involves collective effects"""
for elem in rg:
if elem.is_collective:
return True
return False
if not isinstance(lattice, list):
lattice = list(lattice)
refpts = get_uint32_index(lattice, refpts)
bunch_currents = getattr(lattice, 'bunch_currents', np.zeros(1))
bunch_spos = getattr(lattice, 'bunch_spos', np.zeros(1))
kwargs.update(bunch_currents=bunch_currents, bunch_spos=bunch_spos)
kwargs['reuse'] = kwargs.pop('keep_lattice', False)
any_collective = collective(lattice)
rshape = r_in.shape
if len(rshape) >= 2 and rshape[1] > 1 and not any_collective:
if pool_size is None:
pool_size = min(len(r_in[0]), multiprocessing.cpu_count(),
DConstant.patpass_poolsize)
return _pass(lattice, r_in, pool_size, start_method, nturns=nturns,
refpts=refpts, **kwargs)
else:
if any_collective:
warn(AtWarning('Collective PassMethod found: use single process'))
else:
warn(AtWarning('no parallel computation for a single particle'))
return _atpass(lattice, r_in, nturns=nturns, refpts=refpts, **kwargs)