Source code for at.plot.observables

from __future__ import annotations

__all__ = ["plot_observables"]

from collections.abc import Mapping, Generator
import itertools

from matplotlib.axes import Axes
from matplotlib.lines import Line2D

from ..lattice import Lattice, Refpts, All
from ..latticetools import ElementObservable, ObservableList
from .synopt import plot_synopt

# Default value of the "slices" argument of plot_observables: the number of
# slices the lattice is split into to get smooth curves.
_SLICES = 400


class _RingSplitter:
    """Generate a ring split in small slices for plotting.

    The sliced lattice is built lazily, on the first access to
    :py:attr:`ring`, and then cached so that the slicing is done at most once.

    Args:
        ring: Lattice to be sliced. Its ``s_range`` attribute is set to
          *s_range*: this is a side effect on the caller's object.
        s_range: Range of interest, assigned as is to ``ring.s_range``.
        slices: Number of slices, forwarded to ``ring.slice``.
    """

    def __init__(
        self, ring: Lattice, s_range: tuple[float, float] | None, slices: int
    ):
        ring.s_range = s_range
        # Original lattice, kept for deferred slicing
        self._ring0 = ring
        # Cached sliced lattice, None until first requested
        self._ring1 = None
        # s positions along the sliced lattice, None until the lattice is sliced
        self.s = None
        self.slices = slices

    @property
    def ring(self) -> Lattice:
        """Sliced lattice, computed on first access and cached.

        The first access also fills the ``s`` attribute with the positions
        returned by ``get_s_pos`` for ``len(sliced_ring) + 1`` points.
        """
        # Compare with None rather than relying on truthiness: an empty sliced
        # lattice is falsy and would otherwise be re-sliced on each access.
        if self._ring1 is None:
            rg = self._ring0.slice(slices=self.slices)
            self.s = rg.get_s_pos(range(len(rg) + 1))
            self._ring1 = rg
        return self._ring1


def plot_observables(
    ring: Lattice,
    obsleft: ObservableList,
    *obsright: ObservableList,
    s_range: tuple[float, float] | None = None,
    axes: Axes | None = None,
    xlabel: str = "",
    ylabel: str = "",
    title: str = "",
    slices: int = _SLICES,
    dipole: dict | None = {},  # noqa: B006
    quadrupole: dict | None = {},  # noqa: B006
    sextupole: dict | None = {},  # noqa: B006
    multipole: dict | None = {},  # noqa: B006
    monitor: dict | None = {},  # noqa: B006
    labels: Refpts = None,
    **kwargs,
) -> tuple[Axes, ...]:
    # noinspection PyUnresolvedReferences
    r"""Plot element observables along a lattice.

    Args:
        ring:           Lattice description. Its ``s_range`` attribute is set
          to *s_range*,
        obsleft:        List of :py:class:`.ElementObservable` plotted against
          the left axis. If refpts is :py:obj:`.All`, a line is drawn.
          Otherwise, markers are drawn. It is recommended to use Observables
          with scalar values. Otherwise, all the values are plotted but share
          the same line properties and legend,
        obsright:       Optional list of :py:class:`.ElementObservable` plotted
          against the right axis. At most one list may be given,
        axes:           :py:class:`~.matplotlib.axes.Axes` in which the
          observables are plotted. If :py:obj:`None`, a new figure is created,
        s_range:        Lattice range of interest, default: whole lattice,
        slices:         Number of lattice slices for getting smooth curves.
          Default: 400,
        xlabel:         x-axis label. Default: ``s [m]``,
        ylabel:         y-axis label. Default:
          :py:attr:`.ObservableList.axis_label`,
        title:          Plot title. It is transmitted to
          :py:func:`.plot_synopt` and therefore not applied when plotting in
          existing axes,
        **kwargs:       The remaining keywords are transmitted to the
          ``evaluate`` method of the observable lists.

    The following keywords are transmitted to the :py:func:`.plot_synopt`
    function. They apply to the main (left) axis and are ignored when plotting
    in existing axes:

    Keyword Args:
        labels (Refpts):    Select elements for which the name is displayed.
          Default: :py:obj:`None`,
        dipole (dict):      Dictionary of properties overloading the default
          properties of the dipole representation. Example:
          :pycode:`{"facecolor": "xkcd:electric blue"}`. If :py:obj:`None`,
          dipoles are not shown.
        quadrupole (dict):  Same definition as for dipole,
        sextupole (dict):   Same definition as for dipole,
        multipole (dict):   Same definition as for dipole,
        monitor (dict):     Same definition as for dipole.

    Returns:
        axes: tuple of :py:class:`~.matplotlib.axes.Axes`. Contains 2 elements
          if there is a plot on the right y-axis, 1 element otherwise.

    Raises:
        ValueError: if *axes* is neither :py:obj:`None` nor an
          :py:class:`~.matplotlib.axes.Axes` object, if more than one
          *obsright* list is given, or if an observable is not an
          :py:class:`.ElementObservable`.

    .. note::

       The legend of the plot is controlled by the
       :py:attr:`.Observable.label` attributes. Default values are provided,
       but labels may be explicitly set. Labels may contain LaTeX math code.
       Example: ``"$\beta_x$"`` will appear as ":math:`\beta_x`".

       Labels starting with an underscore do not appear in the legend.

    Examples:

        Minimal example using default values:

        >>> obsmu = at.ObservableList(
        ...     [
        ...         at.LocalOpticsObservable(at.All, "mu", plane=0),
        ...         at.LocalOpticsObservable(at.All, "mu", plane=1),
        ...     ]
        ... )
        >>>
        >>> (ax1,) = at.plot_observables(ring, obsmu)

        .. image:: /images/plot_observables/phase_obs.*
           :alt: phase advance plot

        This example demonstrates the use of the postfun post-processing
        attribute of observables to plot the beam envelopes for arbitrary
        emittance values:

        >>> # Define the emittances
        >>> emit_x = 130.0e-12
        >>> emit_y = 10.0e-12
        >>>
        >>> # beam size computation
        >>> sigma_x = lambda x: 1.0e6 * np.sqrt(x * emit_x)  # result in um
        >>> sigma_y = lambda y: 1.0e6 * np.sqrt(y * emit_y)  # result in um
        >>>
        >>> # Observables
        >>> obsenv = at.ObservableList(
        ...     [
        ...         at.LocalOpticsObservable(
        ...             at.All, "beta", plane="x", postfun=sigma_x, label=r"$\sigma_x$"
        ...         ),
        ...         at.LocalOpticsObservable(
        ...             at.All, "beta", plane="y", postfun=sigma_y, label=r"$\sigma_y$"
        ...         ),
        ...     ]
        ... )
        >>> (ax2,) = at.plot_observables(
        ...     ring, obsenv, ylabel=r"Beam size [${\mu}m$]", title="Beam envelopes"
        ... )

        .. image:: /images/plot_observables/envelope_obs.*
           :alt: envelope plot
    """

    def get_format(obs: ElementObservable, default: str) -> tuple[tuple, Mapping]:
        """Return the positional and keyword format arguments for ``plot``.

        The "plot_fmt" attribute of the observable, if present, replaces
        *default*. A mapping is used as keyword arguments, anything else as a
        single positional format argument.
        """
        fmt = getattr(obs, "plot_fmt", default)
        if isinstance(fmt, Mapping):
            return (), fmt
        return (fmt,), {}

    def axes1(axes: Axes, obs: ObservableList) -> Generator[Line2D, None, None]:
        """Plot all observables on a given axis."""

        def plot1(obs: ElementObservable):
            """Plot a single observable."""
            if obs.refpts is All:
                # Evaluated on the sliced lattice: drawn as a line
                fmt_args, fmt_kwargs = get_format(obs, f"C{next(curve_counter)}")
                x = curve_s
            else:
                # Evaluated on the original lattice: drawn as markers
                fmt_args, fmt_kwargs = get_format(obs, f"oC{next(dot_counter)}")
                x = dot_s[obs._boolrefs]
            return axes.plot(x, obs.value, *fmt_args, label=obs.label, **fmt_kwargs)

        axes.set_ylabel(obs.axis_label)
        for ob in obs:
            yield from plot1(ob)

    def evaluate(obs: ObservableList) -> None:
        """Evaluate all the observables of one list.

        Observables with refpts set to All are evaluated on the sliced
        lattice, the other ones on the original lattice.
        """
        curves = ObservableList(**obs.eval_kw)
        dots = ObservableList(**obs.eval_kw)
        for ob in obs:
            if not isinstance(ob, ElementObservable):
                msg = f"{ob.name} is not an ElementObservable object."
                raise ValueError(msg)
            elif ob.refpts is All:
                curves.append(ob)
            else:
                dots.append(ob)
        if curves:
            # Evaluate curve data
            curves.evaluate(ring=splitter.ring, **kwargs)
        if dots:
            # Evaluate marker data
            dots.evaluate(ring=ring, **kwargs)

    # Validate the arguments before any side effect (modification of
    # ring.s_range, figure creation, evaluation of the observables).
    if len(obsright) > 1:
        # Only one right axis exists: additional lists could not be plotted
        # and would otherwise fail late in the strict zip() below.
        msg = "At most one ObservableList can be plotted against the right axis."
        raise ValueError(msg)
    if not (axes is None or isinstance(axes, Axes)):
        msg = "The 'axes' argument must be an Axes object."
        raise ValueError(msg)

    splitter = _RingSplitter(ring, s_range, slices)
    if axes is None:
        _, axleft = plot_synopt(
            ring,
            title=title,
            dipole=dipole,
            quadrupole=quadrupole,
            sextupole=sextupole,
            multipole=multipole,
            monitor=monitor,
            labels=labels,
        )
    else:
        axleft = axes

    allaxes = (axleft, axleft.twinx()) if obsright else (axleft,)
    allobs = (obsleft, *obsright)

    for obs in allobs:
        evaluate(obs)

    # splitter.s is only filled once the sliced lattice has been requested,
    # which happens in evaluate() whenever a curve observable is present.
    curve_s = splitter.s
    dot_s = ring.get_s_pos(range(len(ring) + 1))
    # Separate colour cycles for lines and markers, shared by both axes
    curve_counter = itertools.count()
    dot_counter = itertools.count()

    lines = []
    for ax, obs in zip(allaxes, allobs, strict=True):
        lines.extend(axes1(ax, obs))

    axleft.set_xlabel(xlabel or "s [m]")
    if ylabel:
        axleft.set_ylabel(ylabel)
    # A single legend on the left axis gathers the lines of both axes
    axleft.legend(handles=[ln for ln in lines if not ln.get_label().startswith("_")])
    axleft.grid(True)
    return allaxes