Source code for at.plot.response

"""Plot :py:class:`.Observable` values as a function of a
:py:class:`Variable <.VariableBase>`.
"""

from __future__ import annotations

__all__ = ["plot_response"]

from collections.abc import Mapping, Iterable
import itertools

import matplotlib.pyplot as plt
from matplotlib.axes import Axes

from ..lattice import VariableBase
from ..latticetools import ObservableList


[docs] def plot_response( var: VariableBase, rng: Iterable[float], obsleft: ObservableList, *obsright: ObservableList, axes: Axes | None = None, xlabel: str = "", ylabel: str = "", **kwargs, ) -> tuple[Axes]: # noinspection PyUnresolvedReferences r"""Plot :py:class:`.Observable` values as a function of a :py:class:`Variable <.VariableBase>`. Args: var: Variable object, rng: range of variation for the variable, obsleft: List of Observables plotted on the left axis. It is recommended to use Observables with scalar values. Otherwise, all the values are plotted but share the same line properties and legend, obsright: Optional list of Observables plotted on the right axis. axes: :py:class:`~matplotlib.axes.Axes` object in which the figure is plotted. If :py:obj:`None`, a new figure is created. xlabel: x-axis label. Default: variable name. ylabel: y-axis label. Default: :py:attr:`.ObservableList.axis_label`. Additional keyword arguments are transmitted to the :py:class:`~matplotlib.axes.Axes` creation function.They apply to the main (left) axis and are ignored when plotting in exising axes: Keyword Args: title (str): Plot title, ylim (tuple): Y-axis limits, *: for other keywords see :py:obj:`~.matplotlib.figure.Figure.add_subplot` Returns: axes: tuple of :py:class:`~.matplotlib.axes.Axes`. Contains 2 elements if there is a plot on the right y-axis, 1 element otherwise. .. note:: The legend of the plot is controlled by the :py:attr:`.Observable.label` attributes. Default values are provided, but labels may explicitly set. Labels may contain LaTeX math code. Example: ``"$\beta_x$"`` will appear as ":math:`\beta_x`". Labels starting with an underscore do not appear in the legend. Example: Minimal example using only default values: >>> obsl = at.ObservableList( ... [at.EmittanceObservable("emittances", plane="x")], ... ring=ring, ... ) >>> obsr = at.ObservableList( ... [at.EmittanceObservable("sigma_e")], ... ring=ring, ... ) >>> var = at.AttributeVariable(ring, "energy", name="energy [eV]") >>> ax1, ax2 = at.plot_response( ... var, np.arange(3.0e9, 6.01e9, 0.5e9), obsl, obsr ... ) >>> .. image:: /images/emittance_response.* :alt: emittance response Example showing the formatting possibilities by: - using the :py:attr:`.Observable.plot_fmt` attribute for line formatting, - using dual y-axis, - using the *ylim* and *title* parameters. >>> obsleft = at.ObservableList( ... [ ... at.LocalOpticsObservable( ... [0], "beta", plane="x", ... plot_fmt={"linewidth": 3.0, "marker": "o"} ... ), ... at.LocalOpticsObservable([0], "beta", plane="y", plot_fmt="--"), ... ], ... ring=ring, ... ) >>> >>> obsright = at.ObservableList( ... [ ... at.GlobalOpticsObservable("tune", plane="x"), ... at.GlobalOpticsObservable("tune", plane="y"), ... ], ... ring=ring, ... ) >>> >>> var = RefptsVariable( ... "QF1[AE]", "Kn1L", name="QF1 integrated strength", ring=ring ... ) >>> ax = at.plot_response( ... var, ... np.arange(0.732, 0.852, 0.01), ... obsleft, ... obsright, ... ylim=[0.0, 10.0], ... title="Example of plot_response", ... ) .. image:: /images/beta_response.* :alt: beta response Example varying an evaluation parameter: >>> obs = at.ObservableList( ... [ ... at.LocalOpticsObservable([0], "beta", plane="x"), ... at.LocalOpticsObservable([0], "beta", plane="y"), ... ], ... ring=ring, ... dp=0.0, ... ) >>> var = at.EvaluationVariable(obsleft, "dp", name=r"$\delta$") >>> ax = at.plot_response(var, np.arange(-0.03, 0.0301, 0.001), obsleft) .. image:: /images/delta_response.* :alt: delta response """ def compute(v, obs): """Evaluate the observables for 1 variable value.""" var.value = v for ob in obs: yield from ob.evaluate() def axes1(axes: Axes, obs: ObservableList): """Plot all observables on a given axis.""" def plot1(obs, ncurve): """Plot 1 curve.""" fmt = getattr(obs, "plot_fmt", f"C{ncurve}") if isinstance(fmt, Mapping): return axes.plot(xx, next(values), label=obs.label, **fmt) else: return axes.plot(xx, next(values), fmt, label=obs.label) axes.set_ylabel(obs.axis_label) for ob in obs: yield from plot1(ob, next(line_counter)) if isinstance(rng, ObservableList): # swap arguments for the old argument order rng, obsleft = obsleft, rng if axes is None: _, axleft = plt.subplots(subplot_kw=kwargs) elif isinstance(axes, Axes): axleft = axes else: msg = "The 'axes' argument must be an Axes object." raise ValueError(msg) allaxes = (axleft, axleft.twinx()) if obsright else (axleft,) allobs = (obsleft, *obsright) # Compute all the observable values with var.restore(): vals = [(v, *compute(v, allobs)) for v in rng] xx, *yy = zip(*vals, strict=True) values = iter(yy) line_counter = itertools.count() lines = [] for ax, obs in zip(allaxes, allobs, strict=True): lines.extend(axes1(ax, obs)) axleft.set_xlabel(xlabel or var.name) if ylabel: axleft.set_ylabel(ylabel) axleft.legend(handles=lines) axleft.grid(True) return allaxes