Source code for at.latticetools.observablelist

r"""Grouping of :py:class:`.Observable` objects for fast evaluation.

:py:class:`ObservableList`\ (lattice, ...)
    This container based on :py:class:`list` is in charge of optics
    computations and provides each individual :py:class:`.Observable` with its
    specific data.

:py:class:`ObservableList` provides the :py:meth:`~ObservableList.evaluate`
method and the following properties:

- :py:attr:`~ObservableList.values`
- :py:attr:`~ObservableList.flat_values`
- :py:attr:`~ObservableList.weights`
- :py:attr:`~ObservableList.flat_weights`
- :py:attr:`~ObservableList.weighted_values`
- :py:attr:`~ObservableList.flat_weighted_values`
- :py:attr:`~ObservableList.deviations`
- :py:attr:`~ObservableList.flat_deviations`
- :py:attr:`~ObservableList.residuals`
- :py:attr:`~ObservableList.sum_residuals`
"""

from __future__ import annotations

__all__ = [
    "ObservableList",
]

from collections.abc import Iterable, Iterator
from functools import reduce
from typing import Callable

import numpy as np
import numpy.typing as npt

# noinspection PyProtectedMember
from .observables import Observable, ElementObservable, Need
from .rdt_observable import RDTObservable
from ..lattice import AtError, frequency_control
from ..lattice import Lattice, Orbit, Refpts, All
from ..physics import linopt6
from ..tracking import internal_lpass


def _flatten(vals, order="F") -> npt.NDArray[float]:
    """Concatenate all values into a single 1-D array.

    Args:
        vals: Iterable of scalars or arrays
        order: Ordering used when flattening each item.
          See :py:func:`~numpy.reshape`

    Returns:
        flat: 1-D array made of each flattened item, in iteration order. An empty
        input gives an empty array of shape ``(0,)``.
    """
    pieces = [np.reshape(v, -1, order=order) for v in vals]
    if not pieces:
        # np.concatenate refuses an empty sequence
        return np.empty(0, dtype=float)
    return np.concatenate(pieces)


class _ObsResIter(Iterator):
    """Iterator returned when looping over an _ObsResults object.

    Each item taken from the wrapped iterator goes through
    ``Observable.check_value`` before being handed to the caller.
    """

    def __init__(self, obsiter):
        # Underlying iterator over the raw stored items
        self.base = obsiter

    def __next__(self):
        raw = next(self.base)
        # A stored error is raised here, when the missing value is reached
        return Observable.check_value(raw)


class _ObsResults(tuple):
    """Tuple-like container for the output of ObservableList.evaluate

    When an evaluation ends with an error, the error is stored in place of the
    value. It is raised a posteriori, when the missing value is accessed by index
    or by iteration. Slicing returns another _ObsResults holding the raw items,
    and the representation shows the raw items.
    """

    def __getitem__(self, item):
        raw = super().__getitem__(item)
        if isinstance(item, slice):
            # Keep the deferred-error behaviour for the sub-sequence
            return _ObsResults(raw)
        # Raises the stored error when accessing a missing value
        return Observable.check_value(raw)

    def __iter__(self):
        raw_iter = tuple.__iter__(self)
        return _ObsResIter(raw_iter)

    def __repr__(self):
        # Go through the plain tuple iterator so that stored errors are displayed
        plain = tuple(tuple.__iter__(self))
        return repr(plain)


class ObservableList(list):
    """Handles a list of Observables to be evaluated together.

    :py:class:`ObservableList` supports all :py:class:`list` methods, like
    appending, insertion or concatenation with the "+" operator.
    """

    # Needs which cannot be fulfilled without a lattice
    needs_ring = {
        Need.RING,
        Need.ORBIT,
        Need.MATRIX,
        Need.GLOBALOPTICS,
        Need.LOCALOPTICS,
        Need.TRAJECTORY,
        Need.EMITTANCE,
        Need.GEOMETRY,
    }
    # Needs which require a starting orbit
    needs_orbit = {
        Need.ORBIT,
        Need.MATRIX,
        Need.GLOBALOPTICS,
        Need.LOCALOPTICS,
        Need.EMITTANCE,
    }
    # Needs which trigger the linear optics computation
    needs_optics = {Need.GLOBALOPTICS, Need.LOCALOPTICS}

    def __init__(
        self,
        obsiter: Iterable[Observable] = (),
        *,
        method: Callable = linopt6,
        orbit: Orbit = None,
        twiss_in=None,
        r_in: Orbit = None,
        **kwargs,
    ):
        # noinspection PyUnresolvedReferences
        r"""
        Args:
            obsiter:    Iterable of :py:class:`.Observable`\ s

        Keyword Args:
            orbit (Orbit):      Initial orbit. Avoids looking for the closed
              orbit if it is already known. Used for
              :py:class:`MatrixObservable` and
              :py:class:`LocalOpticsObservable`
            twiss_in:           Initial conditions for transfer line optics.
              See :py:func:`.get_optics`. Used for
              :py:class:`LocalOpticsObservable`
            method (Callable):  Method for linear optics. Used for
              :py:class:`LocalOpticsObservable`.
              Default: :py:obj:`~.linear.linopt6`
            r_in (Orbit):       Initial trajectory, used for
              :py:class:`TrajectoryObservable`, Default: zeros(6)
            **kwargs:           Additional keywords, stored in the *kwargs*
              attribute

        Example:
            >>> obslist = ObservableList()

            Create an empty Observable list

            >>> obslist.append(OrbitObservable(at.Monitor, plane="x"))
            >>> obslist.append(GlobalOpticsObservable("tune"))
            >>> obslist.append(EmittanceObservable("emittances", plane="h"))

            Add observables for horizontal closed orbit at Monitor locations,
            tunes and horizontal emittance

            >>> obslist.evaluate(ring, initial=True)

            Compute the values and mark them as the initial value

            >>> obslist.values
            [array([-3.02189464e-09,  4.50695130e-07,  4.08205818e-07,
                     2.37899969e-08, -1.31783561e-08,  2.47230794e-08,
                     2.95310770e-08, -4.05598110e-07, -4.47398092e-07,
                     2.24850671e-09]),
             array([3.81563019e-01, 8.54376397e-01, 1.09060730e-04]),
             1.320391045951568e-10]

            >>> obslist.get_flat_values("tune", "emittances[h]")
            array([3.815630e-01, 8.543764e-01, 1.090607e-04, 1.320391e-10])

            Get a flattened array of tunes and horizontal emittance
        """
        # Reference points per computation, filled by _setup()
        self.orbitrefs = None
        self.opticsrefs = None
        self.passrefs = None
        self.matrixrefs = None
        self.rdtrefs = None
        # None means that _setup() must run before the next evaluation
        self.needs = None
        self.rdt_type = set()
        self.method = method
        self.orbit = orbit
        self.twiss_in = twiss_in
        self.r_in = r_in
        self.kwargs = kwargs
        super().__init__(obsiter)

    # noinspection PyProtectedMember
    def _setup(self, ring: Lattice):
        """Initialise the observables and summarise their needs.

        Stores the union of the needs of all observables in *needs*, the set of
        RDT types in *rdt_type* and, when a lattice is given, one boolean mask
        of reference points per kind of computation.

        Args:
            ring:   Lattice used for evaluation, or :py:obj:`None`

        Raises:
            ValueError: *ring* is :py:obj:`None` while an observable needs it
        """
        # Compute the union of all needs
        needs = set()
        rdt_type = set()
        for obs in self:
            needs |= obs.needs
            if isinstance(obs, ElementObservable):
                needs.add(Need.RING)
            if isinstance(obs, RDTObservable):
                rdt_type.add(obs._rdt_type)
        if (needs & self.needs_ring) and ring is None:
            raise ValueError("At least one Observable needs a ring argument")
        self.needs = needs
        self.rdt_type = rdt_type

        if ring is None:
            # Initialise each observable
            for obs in self:
                obs._setup(ring)
        else:
            # Initialise each observable and make a summary of all refpoints
            noref = ring.get_bool_index(None)
            orbitrefs = opticsrefs = passrefs = matrixrefs = rdtrefs = noref
            for obs in self:
                obs._setup(ring)
                obsneeds = obs.needs
                if isinstance(obs, ElementObservable):
                    # The masks start as the same object: "|=" on an array works
                    # in place and would turn every mask into the union of all
                    # refpoints. "a = a | b" keeps each mask independent.
                    if Need.ORBIT in obsneeds:
                        orbitrefs = orbitrefs | obs._boolrefs
                    if Need.MATRIX in obsneeds:
                        matrixrefs = matrixrefs | obs._boolrefs
                    if Need.LOCALOPTICS in obsneeds:
                        if Need.ALL_POINTS in obsneeds:
                            opticsrefs = ring.get_bool_index(All)
                        else:
                            opticsrefs = opticsrefs | obs._boolrefs
                    if Need.TRAJECTORY in obsneeds:
                        passrefs = passrefs | obs._boolrefs
                    if Need.RDT in obsneeds:
                        rdtrefs = rdtrefs | obs._boolrefs
            self.orbitrefs = orbitrefs
            self.opticsrefs = opticsrefs
            self.rdtrefs = rdtrefs
            self.passrefs = passrefs
            self.matrixrefs = matrixrefs

    def __iadd__(self, other: ObservableList):
        """Extend in place with another :py:class:`ObservableList`.

        Raises:
            TypeError: *other* is not an :py:class:`ObservableList`
        """
        if not isinstance(other, ObservableList):
            raise TypeError(f"Cannot add a {type(other)} to an ObservableList")
        self.extend(other)
        return self

    def __add__(self, other) -> ObservableList:
        """Return a new list made of the observables of *self* and *other*.

        The new list is built from the observables only, with default keyword
        arguments.
        """
        nobs = ObservableList(self)
        nobs += other
        return nobs

    def append(self, obs: Observable):
        """Append observable to the end of the list."""
        if not isinstance(obs, Observable):
            raise TypeError(f"Cannot append a {type(obs)} to an ObservableList")
        self.needs = None
        super().append(obs)

    def extend(self, obsiter: Iterable[Observable]):
        """Extend list by appending Observables from the iterable."""
        self.needs = None
        super().extend(obsiter)

    def insert(self, index: int, obs: Observable):
        """Insert Observable before index."""
        if not isinstance(obs, Observable):
            raise TypeError(f"Cannot insert a {type(obs)} in an ObservableList")
        self.needs = None
        super().insert(index, obs)

    # noinspection PyProtectedMember
    def __str__(self):
        values = "\n".join(obs._all_lines() for obs in self)
        return "\n".join((Observable._header(), values))

    def evaluate(
        self,
        ring: Lattice | None = None,
        *,
        dp: float | None = None,
        dct: float | None = None,
        df: float | None = None,
        initial: bool = False,
        **kwargs,
    ):
        r"""Compute all the :py:class:`Observable` values.

        Args:
            ring:           Lattice used for evaluation
            dp (float):     Momentum deviation. Defaults to :py:obj:`None`
            dct (float):    Path lengthening. Defaults to :py:obj:`None`
            df (float):     Deviation from the nominal RF frequency.
              Defaults to :py:obj:`None`
            initial:    If :py:obj:`True`, store the values as *initial values*

        Keyword Args:
            orbit (Orbit):      Initial orbit. Avoids looking for the closed
              orbit if it is already known. Used for
              :py:class:`.MatrixObservable` and
              :py:class:`.LocalOpticsObservable`
            twiss_in:           Initial conditions for transfer line optics.
              See :py:func:`.get_optics`. Used for
              :py:class:`.LocalOpticsObservable`
            method (Callable):  Method for linear optics. Used for
              :py:class:`.LocalOpticsObservable`.
              Default: :py:obj:`~.linear.linopt6`
            r_in (Orbit):       Initial trajectory, used for
              :py:class:`.TrajectoryObservable`, Default: zeros(6)
            use_mp (bool):      Forwarded to the RDT computation.
              Default: :py:obj:`False`
            pool_size:          Forwarded to the RDT computation.
              Default: :py:obj:`None`

        Returns:
            results:    Tuple-like object with one item per observable. An
            error stored during the evaluation is raised when the
            corresponding item is accessed.
        """

        def obseval(ring, obs):
            """Evaluate a single observable."""

            def check_error(data, refpts):
                # A stored exception is forwarded as is to the observable
                return data if isinstance(data, Exception) else data[refpts]

            obsneeds = obs.needs
            obsrefs = getattr(obs, "_boolrefs", None)
            # obsrefs[listrefs] selects, among the points computed for the
            # whole list, those requested by this observable
            data = []
            if Need.RING in obsneeds:
                data.append(ring)
            if Need.ORBIT in obsneeds:
                data.append(check_error(orbits, obsrefs[self.orbitrefs]))
            if Need.MATRIX in obsneeds:
                data.append(check_error(mxdata, obsrefs[self.matrixrefs]))
            if Need.GLOBALOPTICS in obsneeds:
                data.append(rgdata)
            if Need.LOCALOPTICS in obsneeds:
                data.append(check_error(eldata, obsrefs[self.opticsrefs]))
            if Need.TRAJECTORY in obsneeds:
                data.append(trajs[obsrefs[self.passrefs]])
            if Need.EMITTANCE in obsneeds:
                data.append(emdata)
            if Need.GEOMETRY in obsneeds:
                data.append(geodata[obsrefs])
            if Need.RDT in obsneeds:
                data.append(check_error(rdtdata, obsrefs[self.rdtrefs]))
            return obs.evaluate(*data, initial=initial)

        @frequency_control
        def ringeval(
            ring,
            dp: float | None = None,
            dct: float | None = None,
            df: float | None = None,
        ):
            """Optics computations."""
            keep_lattice = False
            trajs = orbits = rgdata = eldata = emdata = mxdata = geodata = rdtdata = (
                None
            )
            # Keywords given to evaluate() take precedence over the defaults
            # stored at construction
            twiss_in = kwargs.get("twiss_in", self.twiss_in)
            o0 = kwargs.get("orbit", self.orbit)
            o0 = getattr(twiss_in, "closed_orbit", None) if o0 is None else o0
            needs = self.needs
            needs_o0 = (needs & self.needs_orbit) and (o0 is None)

            if Need.TRAJECTORY in needs:
                # Trajectory computation
                r_in = kwargs.get("r_in", self.r_in)
                if r_in is None:
                    r_in = np.zeros(6)
                r_out = internal_lpass(ring, r_in.copy(), 1, refpts=self.passrefs)
                trajs = r_out[:, 0, :, 0].T
                keep_lattice = True

            if Need.ORBIT in needs or needs_o0:
                # Closed orbit computation
                try:
                    o0, orbits = ring.find_orbit(
                        refpts=self.orbitrefs,
                        dp=dp,
                        dct=dct,
                        df=df,
                        orbit=o0,
                        keep_lattice=keep_lattice,
                    )
                except AtError as err:
                    # Store the error for everything depending on the orbit
                    orbits = mxdata = rgdata = eldata = emdata = err
                else:
                    keep_lattice = True

            if Need.MATRIX in needs and o0 is not None:
                # Transfer matrix computation
                find_m = ring.find_m66 if ring.is_6d else ring.find_m44
                # noinspection PyUnboundLocalVariable
                _, mxdata = find_m(
                    refpts=self.matrixrefs,
                    dp=dp,
                    dct=dct,
                    df=df,
                    orbit=o0,
                    keep_lattice=keep_lattice,
                )
                keep_lattice = True

            if (needs & self.needs_optics) and o0 is not None:
                # Linear optics computation
                try:
                    _, rgdata, eldata = ring.get_optics(
                        refpts=self.opticsrefs,
                        dp=dp,
                        dct=dct,
                        df=df,
                        orbit=o0,
                        keep_lattice=keep_lattice,
                        get_chrom=Need.CHROMATICITY in needs,
                        get_w=Need.W_FUNCTIONS in needs,
                        twiss_in=twiss_in,
                        method=kwargs.get("method", self.method),
                    )
                except AtError as err:
                    rgdata = eldata = err
                else:
                    keep_lattice = True

            if Need.EMITTANCE in needs and o0 is not None:
                # Emittance computation
                try:
                    emdata = ring.envelope_parameters(
                        orbit=o0, keep_lattice=keep_lattice
                    )
                except Exception as err:
                    # Deliberately broad: the error is stored instead of the value
                    emdata = err

            if Need.GEOMETRY in needs:
                # Geometry computation
                geodata, _ = ring.get_geometry()

            if Need.RDT in needs:
                # RDT computation
                use_mp = kwargs.get("use_mp", False)
                pool_size = kwargs.get("pool_size", None)
                try:
                    _, _, rdtdata = ring.get_rdts(
                        refpts=self.rdtrefs,
                        rdt_type=self.rdt_type,
                        second_order=Need.RDT_2ND_ORDER in needs,
                        use_mp=use_mp,
                        pool_size=pool_size,
                    )
                except Exception as err:
                    # Deliberately broad: the error is stored instead of the value
                    rdtdata = err

            return trajs, orbits, rgdata, eldata, emdata, mxdata, geodata, rdtdata

        # Redo the setup after any modification of the list, or on request of
        # initial values
        if self.needs is None or initial:
            self._setup(ring)

        trajs, orbits, rgdata, eldata, emdata, mxdata, geodata, rdtdata = ringeval(
            ring, dp=dp, dct=dct, df=df
        )
        return _ObsResults(obseval(ring, ob) for ob in self)

    def check(self) -> bool:
        """Check if all observables are evaluated.

        Returns:
            ok: :py:obj:`True` if evaluation is done, :py:obj:`False` otherwise

        Raises:
            AtError: any value is doubtful: evaluation failed, empty value…
        """
        return all(obs.check() for obs in self)

    # noinspection PyProtectedMember
    def exclude(self, obsname: str, excluded: Refpts):
        """Set the excluded mask on the selected observable."""
        for obs in self:
            if obs.name == obsname:
                obs._excluded = excluded
        # Force a new setup at the next evaluation
        self.needs = None

    def _lookup(self, *ids: int | str) -> list[Observable]:
        """Observable lookup function

        Args:
            *ids:   name or index of selected observables (Default all)

        Returns:
            obslist:    The selected observables, or *self* if no id is given

        Raises:
            KeyError:   No observable has the requested name
        """

        def select(key):
            if isinstance(key, str):
                # The first observable with a matching name is selected
                for obs in self:
                    if obs.name == key:
                        return obs
                raise KeyError(key)
            return self[key]

        if ids:
            return [select(key) for key in ids]
        return self

    def _collect(self, attrname: str, *obsid: str | int, err: float | None = None):
        """Collect the attribute *attrname* of the selected observables.

        Args:
            attrname:   Name of the attribute to be collected
            *obsid:     name or index of selected observables (Default all)
            err:        Replacement value used when getting the attribute
              raises :py:class:`.AtError`. By default, the exception is raised.

        Returns:
            values:     Tuple with one item per selected observable
        """

        def val(obs):
            try:
                res = getattr(obs, attrname)
            except AtError:
                if err is None:
                    raise
                # noinspection PyProtectedMember
                shp = obs._shape
                res = err if shp is None else np.broadcast_to(err, shp)
            return res

        obslist = self._lookup(*obsid)
        return tuple(val(obs) for obs in obslist)

    def get_shapes(self, *obsid: str | int) -> tuple:
        """Return the shapes of all values.

        Args:
            *obsid: name or index of selected observables (Default all)
        """
        return self._collect("_shape", *obsid)

    def get_flat_shape(self, *obsid: str | int):
        """Return the shape of the flattened values.

        Args:
            *obsid: name or index of selected observables (Default all)
        """
        # Number of items of each observable: product of its dimensions
        vals = (
            reduce(lambda x, y: x * y, shp, 1)
            for shp in self._collect("_shape", *obsid)
        )
        return (sum(vals),)

    def get_values(self, *obsid: str | int, err: float | None = None) -> tuple:
        """Return the values of observables.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.

        Raises:
            Exception:  Any exception raised during evaluation, unless *err*
              has been set.
        """
        return self._collect("value", *obsid, err=err)

    def get_flat_values(
        self, *obsid: str | int, err: float | None = None, order: str = "F"
    ) -> npt.NDArray[float]:
        """Return a 1-D array of Observable values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
            order:  Ordering for reshaping. See :py:func:`~numpy.reshape`
        """
        return _flatten(self._collect("value", *obsid, err=err), order=order)

    def get_weighted_values(self, *obsid: str | int, err: float | None = None) -> tuple:
        """Return the weighted values of observables.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
        """
        return self._collect("weighted_value", *obsid, err=err)

    def get_flat_weighted_values(
        self, *obsid: str | int, err: float | None = None, order: str = "F"
    ) -> npt.NDArray[float]:
        """Return a 1-D array of Observable weighted values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
            order:  Ordering for reshaping. See :py:func:`~numpy.reshape`
        """
        return _flatten(self._collect("weighted_value", *obsid, err=err), order=order)

    def get_deviations(self, *obsid: str | int, err: float | None = None) -> tuple:
        """Return the deviations from target values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
        """
        return self._collect("deviation", *obsid, err=err)

    def get_flat_deviations(
        self, *obsid: str | int, err: float | None = None, order: str = "F"
    ) -> npt.NDArray[float]:
        """Return a 1-D array of deviations from target values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
            order:  Ordering for reshaping. See :py:func:`~numpy.reshape`
        """
        return _flatten(self._collect("deviation", *obsid, err=err), order=order)

    def get_weighted_deviations(
        self, *obsid: str | int, err: float | None = None
    ) -> tuple:
        """Return the weighted deviations from target values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
        """
        return self._collect("weighted_deviation", *obsid, err=err)

    def get_flat_weighted_deviations(
        self, *obsid: str | int, err: float | None = None, order: str = "F"
    ) -> npt.NDArray[float]:
        """Return a 1-D array of weighted deviations from target values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
            order:  Ordering for reshaping. See :py:func:`~numpy.reshape`
        """
        return _flatten(
            self._collect("weighted_deviation", *obsid, err=err), order=order
        )

    def get_weights(self, *obsid: str | int) -> tuple:
        """Return the weights of observables.

        Args:
            *obsid: name or index of selected observables (Default all)
        """
        return self._collect("weight", *obsid)

    def get_flat_weights(
        self, *obsid: str | int, order: str = "F"
    ) -> npt.NDArray[float]:
        """Return a 1-D array of Observable weights.

        Args:
            *obsid: name or index of selected observables (Default all)
            order:  Ordering for reshaping. See :py:func:`~numpy.reshape`
        """
        return _flatten(self._collect("weight", *obsid), order=order)

    def get_residuals(self, *obsid: str | int, err: float | None = None) -> tuple:
        """Return the residuals of observables.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
        """
        return self._collect("residual", *obsid, err=err)

    def get_sum_residuals(self, *obsid: str | int, err: float | None = None) -> float:
        """Return the sum of residual values.

        Args:
            *obsid: name or index of selected observables (Default all)
            err:    Default observable value to be used when the evaluation
              failed. By default, an Exception is raised.
        """
        return sum(np.sum(res) for res in self._collect("residual", *obsid, err=err))

    # Read-only properties: each one calls the matching getter with its default
    # arguments (all observables, exception raised on failed evaluation)
    shapes = property(get_shapes, doc="Shapes of all values")
    flat_shape = property(get_flat_shape, doc="Shape of the flattened values")
    values = property(get_values, doc="values of all observables")
    flat_values = property(get_flat_values, doc="1-D array of Observable values")
    weighted_values = property(
        get_weighted_values, doc="Weighted values of all observables"
    )
    flat_weighted_values = property(
        get_flat_weighted_values, doc="1-D array of Observable weighted values"
    )
    deviations = property(get_deviations, doc="Deviations from target values")
    flat_deviations = property(
        get_flat_deviations, doc="1-D array of deviations from target value"
    )
    weighted_deviations = property(
        get_weighted_deviations, doc="Weighted deviations from target values"
    )
    flat_weighted_deviations = property(
        get_flat_weighted_deviations,
        doc="1-D array of weighted deviations from target values",
    )
    weights = property(get_weights, doc="Weights of all observables")
    flat_weights = property(get_flat_weights, doc="1-D array of Observable weights")
    residuals = property(get_residuals, doc="Residuals of all observable")
    sum_residuals = property(get_sum_residuals, doc="Sum of all residual values")