r"""
Definition of :py:class:`Variable <.VariableBase>` objects used in matching and
response matrices.
See :ref:`example-notebooks` for examples of matching and response matrices.
Each :py:class:`Variable <.VariableBase>` has a scalar value.
.. rubric:: Class hierarchy
:py:class:`VariableBase`\ (name, bounds, delta)
- :py:class:`~.lattice_variables.ElementVariable`\ (elements, attrname, index, ...)
- :py:class:`~.lattice_variables.RefptsVariable`\ (refpts, attrname, index, ...)
- :py:class:`CustomVariable`\ (setfun, getfun, ...)
.. rubric:: VariableBase methods
:py:class:`VariableBase` provides the following methods:
- :py:meth:`~VariableBase.get`
- :py:meth:`~VariableBase.set`
- :py:meth:`~VariableBase.set_previous`
- :py:meth:`~VariableBase.reset`
- :py:meth:`~VariableBase.increment`
- :py:meth:`~VariableBase.step_up`
- :py:meth:`~VariableBase.step_down`
.. rubric:: VariableBase properties
:py:class:`.VariableBase` provides the following properties:
- :py:attr:`~VariableBase.initial_value`
- :py:attr:`~VariableBase.last_value`
- :py:attr:`~VariableBase.previous_value`
- :py:attr:`~VariableBase.history`
The :py:class:`VariableBase` abstract class may be used as a base class to define
custom variables (see examples). Typically, this consists of overriding the
*_getfun* and *_setfun* methods.
.. rubric:: Examples
Write a subclass of :py:class:`VariableBase` which varies two drift lengths so
that their sum is constant:
.. code-block:: python
class ElementShifter(at.VariableBase):
'''Varies the length of the elements *drift1* and *drift2*
keeping the sum of their lengths equal to *total_length*.
If *total_length* is None, it is set to the initial total length
'''
def __init__(self, drift1, drift2, total_length=None, **kwargs):
# store the 2 variable elements
self.drift1 = drift1
self.drift2 = drift2
# store the initial total length
if total_length is None:
total_length = drift1.Length + drift2.Length
self.length = total_length
super().__init__(bounds=(0.0, total_length), **kwargs)
def _setfun(self, value, **kwargs):
self.drift1.Length = value
self.drift2.Length = self.length - value
def _getfun(self, **kwargs):
return self.drift1.Length
And create a variable varying the length of drifts *DR_01* and *DR_02* and
keeping their sum constant:
.. code-block:: python
drift1 = hmba_lattice["DR_01"]
drift2 = hmba_lattice["DR_02"]
var2 = ElementShifter(drift1, drift2, name="DR_01")
"""
from __future__ import annotations
__all__ = [
"AttributeVariable",
"CustomVariable",
"ItemVariable",
"VariableBase",
"VariableList",
"attr_",
"membergetter",
]
import abc
from collections import deque
from collections.abc import (
Iterable,
Sequence,
Callable,
Generator,
MutableMapping,
MutableSequence,
)
from contextlib import suppress, contextmanager
from operator import itemgetter, attrgetter
import numpy as np
import numpy.typing as npt
from .parambase import Number, Operand
[docs]
class attr_(str):
    """String subclass marking a name as an attribute name.

    A plain :py:class:`str` is interpreted as a dictionary key; wrapping the
    string in :py:class:`attr_` identifies it as an attribute name instead.
    """

    __slots__ = []
class _AttributeAccessor:
"""Class object setting/getting an attribute of an object."""
__slots__ = ["attrname", "obj"]
def __init__(self, obj, attrname: str):
self.obj = obj
self.attrname = attrname
def set(self, value: float):
setattr(self.obj, self.attrname, value)
def get(self) -> float:
return getattr(self.obj, self.attrname)
class _ItemAccessor:
"""Class object setting/getting an item of an object."""
__slots__ = ["key", "obj"]
def __init__(self, obj: MutableMapping | MutableSequence, key):
self.obj = obj
self.key = key
def set(self, value: float):
self.obj[self.key] = value
def get(self) -> float:
return self.obj[self.key]
[docs]
class membergetter:
    """Generalised attribute and item lookup.

    Callable object fetching attributes or items from its operand object. This
    generalises :py:func:`~operator.attrgetter` and :py:func:`~operator.itemgetter` and
    allows to extract elements deep in the object structure. For example:

    - With ``f1 = membergetter("key1", 2)``, then ``f1(obj)`` returns
      ``obj["key1"][2]``,
    - With ``f2 = membergetter("key2", attr_("attr1"), "key3")``, then ``f2(obj)``
      returns ``obj["key2"].attr1["key3"]``.
    """

    def __init__(self, *args):
        r"""
        Args:
            *args: Sequence of dictionary keys, sequence indices or
              attribute names. A :py:class:`str` argument is interpreted as a
              dictionary key. Attribute names must be decorated with
              ``attr_(attrname)`` to distinguish them from dictionary keys.

              - ``membergetter("key1")(obj)`` returns ``obj["key1"]``,
              - ``membergetter(attr_("attr1"))(obj)`` returns ``obj.attr1``

        Raises:
            TypeError: If no key, index or attribute name is given

        Example:

            >>> dct = {"a": 42.0, "b": [0.0, 1.0, 2.0, 3.0]}
            >>> f = membergetter("b", 1)
            >>> f(dct)
            1.0
        """
        # Fail early with an explicit message instead of the IndexError that
        # args[-1] would raise on an empty tuple
        if not args:
            msg = "membergetter expected at least 1 argument, got 0"
            raise TypeError(msg)

        def getter(key):
            # attr_ instances select attributes, anything else selects items
            return attrgetter(key) if isinstance(key, attr_) else itemgetter(key)

        # The last key is kept apart: "accessor" needs it to act on the final member
        self.key = args[-1]
        self.getters = [getter(key) for key in args]

    def __call__(self, obj):
        """Return the member of *obj* reached by applying all keys in sequence."""
        for getter in self.getters:
            obj = getter(obj)
        return obj

    def accessor(self, obj):
        """Return an accessor object.

        The returned object has *set* and *get* methods acting on the selected item of
        the object *obj*.

        Example:

            >>> dct = {"a": 42.0, "b": [0.0, 1.0, 2.0, 3.0]}
            >>> v2 = membergetter("b", 1)
            >>> accessor = v2.accessor(dct)
            >>> accessor.get()
            1.0
            >>> accessor.set(42.0)
            >>> dct
            {'a': 42.0, 'b': [0.0, 42.0, 2.0, 3.0]}
        """
        # Walk down to the container of the final member, then bind the last key
        for getter in self.getters[:-1]:
            obj = getter(obj)
        if isinstance(self.key, attr_):
            return _AttributeAccessor(obj, self.key)
        else:
            return _ItemAccessor(obj, self.key)
[docs]
class VariableBase(Operand, abc.ABC):
    """A Variable abstract base class.

    Derived classes must implement the :py:meth:`~VariableBase._getfun` method.
    Writable variables must also override :py:meth:`~VariableBase._setfun`: the
    default implementation raises :py:exc:`TypeError` (read-only variable).
    """

    # Class constants
    DEFAULT_DELTA = 1.0
    _COUNTER_PREFIX = "var"
    _counter = 0

    def __init__(
        self,
        *args,
        name: str = "",
        bounds: tuple[Number | None, Number | None] | None = None,
        delta: Number = DEFAULT_DELTA,
        history_length: int | None = None,
        **kwargs,
    ) -> None:
        """
        Parameters:
            *args:      Positional arguments passed to the _setfun and _getfun methods
            name:       Name of the Variable. If omitted or blank, a unique name is
              generated.
            bounds:     Lower and upper bounds of the variable value
            delta:      Initial variation step
            history_length: Maximum length of the history buffer. :py:obj:`None`
              means infinite.

        Keyword Args:
            **kwargs:   Keyword arguments passed to the _setfun and _getfun methods
        """
        self.name: str = self._generate_name(name)  #: Variable name
        self.args = args
        self.kwargs = kwargs
        if bounds is None:
            bounds = (None, None)
        self._bounds: tuple[Number | None, Number | None] = bounds  #: Variable bounds
        self.delta: Number = delta  #: Increment step
        #: Maximum length of the history buffer. :py:obj:`None` means infinite
        self.history_length = history_length
        # NaN means "no initial value recorded yet"
        self._initial = np.nan
        self._history: deque[Number] = deque([], self.history_length)
        super().__init__(name=self.name)
        # Try to record the initial value now. A ValueError raised while getting it
        # is ignored: the value is then recorded by the first successful get or set
        with suppress(ValueError):
            self.get(initial=True)

    @classmethod
    def _generate_name(cls, name: str) -> str:
        """Generate unique name for variable.

        The counter lives on the class which defines ``_COUNTER_PREFIX``, so that all
        the classes sharing a prefix also share a single counter. Incrementing
        ``cls._counter`` directly would create an independent counter on each
        subclass and give the same generated name to variables of different classes.
        The counter is incremented even when an explicit *name* is given.
        """
        owner = next(k for k in cls.__mro__ if "_COUNTER_PREFIX" in vars(k))
        # A class redefining the prefix without its own counter starts from 0
        owner._counter = vars(owner).get("_counter", 0) + 1
        return name if name else f"{cls._COUNTER_PREFIX}{owner._counter}"

    @property
    def bounds(self) -> tuple[float, float]:
        """Lower and upper bounds of the variable."""
        vmin, vmax = self._bounds
        return -np.inf if vmin is None else vmin, np.inf if vmax is None else vmax

    def check_bounds(self, value: Number) -> None:
        """Check that a value is within the variable bounds.

        Raises:
            ValueError: If the value is not within bounds
        """
        min_val, max_val = self._bounds
        if min_val is not None and value < min_val:
            msg = f"Value {value} must be larger or equal to {min_val}"
            raise ValueError(msg)
        if max_val is not None and value > max_val:
            msg = f"Value {value} must be smaller or equal to {max_val}"
            raise ValueError(msg)

    def _setfun(self, value: Number, *args, **kwargs) -> None:
        # Default implementation: the variable is read-only
        classname = self.__class__.__name__
        msg = f"{classname!r} is read-only"
        raise TypeError(msg)

    @abc.abstractmethod
    def _getfun(self, *args, **kwargs) -> Number: ...

    @property
    def history(self) -> list[Number]:
        """History of the values of the variable."""
        return list(self._history)

    @property
    def initial_value(self) -> Number:
        """Initial value of the variable.

        Raises:
            IndexError: If no initial value has been set yet
        """
        if not np.isnan(self._initial):
            return self._initial
        else:
            msg = f"{self.name}: No initial value has been set yet"
            raise IndexError(msg)

    @property
    def last_value(self) -> Number:
        """Last value of the variable.

        Raises:
            IndexError: If no value has been set yet
        """
        try:
            return self._history[-1]
        except IndexError as exc:
            exc.args = (f"{self.name}: No value has been set yet",)
            raise

    @property
    def previous_value(self) -> Number:
        """Value before the last one.

        Raises:
            IndexError: If there are fewer than 2 values in the history
        """
        try:
            return self._history[-2]
        except IndexError as exc:
            exc.args = (f"{self.name}: history too short (need at least 2 values)",)
            raise

    def set(self, value: Number, **setkw) -> None:
        r"""Set the variable value.

        Args:
            value:      New value to be applied on the variable

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.

        Raises:
            ValueError: If the value is not within bounds
        """
        self.check_bounds(value)
        self._setfun(value, *self.args, **(self.kwargs | setkw))
        # The value is recorded only once it has been successfully applied
        if np.isnan(self._initial):
            self._initial = value
        self._history.append(value)

    def get(
        self, *, initial: bool = False, check_bounds: bool = False, **getkw
    ) -> Number:
        r"""Get the actual variable value.

        Args:
            initial:    If :py:obj:`True`, clear the history and set the variable
              initial value
            check_bounds: If :py:obj:`True`, raise a ValueError if the value is out
              of bounds

        Keyword Args:
            \*\*getkw:    Keyword arguments to be passed to the *getfun* function.
              They augment the keyword arguments given in the constructor.

        Returns:
            value:      Value of the variable
        """
        value = self._getfun(*self.args, **(self.kwargs | getkw))
        if initial or np.isnan(self._initial):
            self._initial = value
            self._history = deque([value], self.history_length)
        if check_bounds:
            self.check_bounds(value)
        return value

    value = property(get, set, doc="Actual value")

    @property
    def _print_value(self):
        # Last recorded value, or NaN when the history is empty
        try:
            return self._history[-1]
        except IndexError:
            return np.nan

    def set_previous(self, **setkw) -> None:
        r"""Reset to the value before the last one.

        If the value cannot be applied, the history is left unchanged.

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.

        Raises:
            IndexError: If there are fewer than 2 values in the history
        """
        if len(self._history) < 2:
            msg = f"{self.name}: history too short (need at least 2 values)"
            raise IndexError(msg)
        last_value = self._history.pop()  # Remove the last value
        previous_value = self._history.pop()  # retrieve the previous value
        try:
            # "set" appends previous_value to the history again
            self.set(previous_value, **setkw)
        except BaseException:
            # Nothing was applied: put back the two removed values
            self._history.extend((previous_value, last_value))
            raise

    def reset(self, **setkw) -> None:
        r"""Reset to the initial value and clear the history buffer.

        If the initial value cannot be applied, the history is left unchanged.

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.

        Raises:
            IndexError: If no initial value has been set yet
        """
        initial_value = self._initial
        if np.isnan(initial_value):
            msg = f"Cannot reset {self.name}: No initial value has been set yet"
            raise IndexError(msg)
        saved_history = self._history
        self._history = deque([], self.history_length)
        try:
            self.set(initial_value, **setkw)
        except BaseException:
            # Nothing was applied: restore the history buffer
            self._history = saved_history
            raise

    def increment(self, incr: Number, **setkw) -> None:
        r"""Increment the variable value.

        Args:
            incr:   Increment value

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.
        """
        try:
            current_value = self.last_value
        except IndexError:
            # If no value has been set yet, get the initial value
            self.get(initial=True, **setkw)
            current_value = self.last_value
        self.set(current_value + incr, **setkw)

    def _step(self, step: Number, **setkw) -> None:
        """Apply a step relative to the initial value."""
        try:
            initial_value = self.initial_value
        except IndexError:
            # If no initial value has been set yet, get it
            self.get(initial=True, **setkw)
            initial_value = self.initial_value
        self.set(initial_value + step, **setkw)

    def step_up(self, **setkw) -> None:
        r"""Set to initial_value + delta.

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.
        """
        self._step(self.delta, **setkw)

    def step_down(self, **setkw) -> None:
        r"""Set to initial_value - delta.

        Keyword Args:
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.
        """
        self._step(-self.delta, **setkw)

    @staticmethod
    def _header():
        # Column titles of the status table
        return "\n{:>12s}{:>13s}{:>16s}{:>16s}\n".format(
            "Name", "Initial", "Final ", "Variation"
        )

    def _line(self, **_):
        # One line of the status table: name, initial value, last value, difference.
        # Keyword arguments are accepted and ignored, because VariableList.status
        # forwards its keywords to the _line method of each variable.
        vnow = self._print_value
        vini = self._initial
        return f"{self.name:>12s}{vini: 16e}{vnow: 16e}{vnow - vini: 16e}"

    def status(self):
        """Return a string describing the current status of the variable.

        Returns:
            status: Variable description
        """
        return "\n".join((self._header(), self._line()))

    @contextmanager
    def restore(self, initial: bool = False, **setkw) -> Generator[None, None, None]:
        # noinspection PyUnresolvedReferences
        r"""Context manager that saves and restore a variable.

        The value of the :py:class:`Variable <VariableBase>` is initially saved, and
        then restored when exiting the context.

        Keyword Args:
            initial:    If :py:obj:`True`, clear the history and set the variable
              initial value
            \*\*setkw:    Keyword arguments to be passed to the *setfun* function.
              They augment the keyword arguments given in the constructor.

        Example:
            >>> var = AttributeVariable(ring, "energy")
            >>> with var.restore():
            ...     do_something(var)
        """
        v0 = self.get(initial=initial, **setkw)
        try:
            yield
        finally:
            # Restore the saved value, even if the body of the context raised
            self.set(v0, **setkw)

    def __float__(self):
        return float(self._print_value)

    def __int__(self):
        return int(self._print_value)

    def __str__(self):
        return self.name

    def __repr__(self):
        return repr(self._print_value)
[docs]
class ItemVariable(VariableBase):
    """Variable bound to one item of a mapping or of a sequence."""

    def __init__(
        self, obj: MutableSequence | MutableMapping, key, *args, **kwargs
    ) -> None:
        # noinspection PyUnresolvedReferences
        """
        Args:
            obj:        Mapping or Sequence holding the value of the variable,
            key:        First dictionary key, sequence index or attribute name on the
              path to the value. A :py:class:`str` is taken as a dictionary key; an
              attribute name must be wrapped as ``attr_(attrname)`` to tell it apart
              from a dictionary key.
            *args:      Further keys, indices or attribute names, used to reach a
              value located deeper in the object structure.

        Keyword Args:
            name:       Name of the Variable. If empty, a unique name is generated.
            bounds:     Lower and upper bounds of the variable value
            delta:      Initial variation step
            history_length: Maximum length of the history buffer. :py:obj:`None`
              means infinite.

        Example:

            >>> dct = {"a": 42.0, "b": [0.0, 1.0, 2.0, 3.0]}
            >>> v1 = at.ItemVariable(dct, "a")
            >>> v1.value
            42.0

            *v1* points to ``dct["a"]``

            >>> v2 = at.ItemVariable(dct, "b", 1)
            >>> v2.value
            1.0

            *v2* points to ``dct["b"][1]``
        """
        # Resolve the path once: the resulting accessor is the single positional
        # argument later handed to _setfun and _getfun
        path = membergetter(key, *args)
        super().__init__(path.accessor(obj), **kwargs)

    def _setfun(self, value, obj, **_):
        # obj is the accessor built in the constructor
        obj.set(value)

    def _getfun(self, obj, **_):
        # obj is the accessor built in the constructor
        return obj.get()
[docs]
class AttributeVariable(VariableBase):
    """Variable bound to an attribute of an object."""

    def __init__(self, obj, attrname: str, index: int | None = None, **kwargs):
        # noinspection PyUnresolvedReferences
        """
        Args:
            obj:        Object holding the value of the variable
            attrname:   Name of the attribute
            index:      Index in the attribute array. Use :py:obj:`None` for
              scalar attributes.

        Keyword Args:
            name:       Name of the Variable. If empty, a unique name is generated.
            bounds:     Lower and upper bounds of the variable value
            delta:      Initial variation step
            history_length: Maximum length of the history buffer. :py:obj:`None`
              means infinite.

        Example:

            >>> ring = at.Lattice.load("hmba.mat")
            >>> v3 = at.AttributeVariable(ring, "energy")
            >>> v3.value
            6000000000.0

            *v3* points to the *"energy"* attribute of *ring*
        """
        # Path to the value: the attribute, optionally followed by an index into it
        path = [attr_(attrname)]
        if index is not None:
            path.append(index)
        super().__init__(membergetter(*path).accessor(obj), **kwargs)

    def _setfun(self, value, obj, **_):
        # obj is the accessor built in the constructor
        obj.set(value)

    def _getfun(self, obj, **_):
        # obj is the accessor built in the constructor
        return obj.get()
[docs]
class CustomVariable(VariableBase):
    r"""A Variable with user-defined get and set functions.

    This is a convenience class allowing user-defined *get* and *set*
    functions. But subclassing :py:class:`VariableBase` should always be preferred
    for clarity and efficiency.
    """

    def __init__(
        self,
        setfun: Callable[..., None],
        getfun: Callable[..., Number],
        *args,
        name: str = "",
        bounds: tuple[Number | None, Number | None] | None = None,
        delta: Number = VariableBase.DEFAULT_DELTA,
        history_length: int | None = None,
        **kwargs,
    ) -> None:
        """
        Parameters:
            getfun:     Function for getting the variable value. Called as
              :pycode:`getfun(*args, **kwargs) -> Number`.
              *args* are the positional arguments given to the constructor, *kwargs*
              are the keyword arguments given to the constructor augmented with the
              keywords given to the :py:meth:`~VariableBase.get` function.
            setfun:     Function for setting the variable value. Called as
              :pycode:`setfun(value: Number, *args, **kwargs): None`.
              *args* are the positional arguments given to the constructor, *kwargs*
              are the keyword arguments given to the constructor augmented with the
              keywords given to the :py:meth:`~VariableBase.set` function.
            name:       Name of the Variable. If empty, a unique name is generated.
            bounds:     Lower and upper bounds of the variable value
            delta:      Initial variation step
            history_length: Maximum length of the history buffer. :py:obj:`None`
              means infinite.
            *args:      Variable argument list transmitted to both the *getfun*
              and *setfun* functions. Such arguments can always be avoided by
              using :py:func:`~functools.partial` or callable class objects.

        Keyword Args:
            **kwargs:   Keyword arguments transmitted to both the *getfun*
              and *setfun* functions. Such arguments can always be avoided by
              using :py:func:`~functools.partial` or callable class objects.
        """
        # The functions must be stored before calling the base constructor, which
        # calls get() and therefore needs getfun
        self.getfun = getfun
        self.setfun = setfun
        super().__init__(
            *args,
            name=name,
            bounds=bounds,
            delta=delta,
            history_length=history_length,
            **kwargs,
        )

    def _getfun(self, *args, **kwargs) -> Number:
        return self.getfun(*args, **kwargs)

    def _setfun(self, value: Number, *args, **kwargs) -> None:
        # Forward the arguments received from set(): kwargs already holds the
        # constructor keywords merged with those given to set(). Using self.args and
        # self.kwargs here would silently drop the keywords given to set().
        self.setfun(value, *args, **kwargs)
[docs]
class VariableList(list):
"""Container for Variable objects.
:py:class:`VariableList` supports all :py:class:`list` methods, like
appending, insertion or concatenation with the "+" operator.
"""
def __getitem__(self, index):
if isinstance(index, slice):
return VariableList(super().__getitem__(index))
else:
return super().__getitem__(index)
[docs]
def get(self, **getkw) -> Sequence[float]:
r"""Get the current values of Variables.
Keyword Args:
initial: If :py:obj:`True`, set the Variables'
initial value
check_bounds: If :py:obj:`True`, raise a ValueError any value is out
of bounds
\*\*getkw: Keyword arguments to be passed to the *getfun* function.
They augment the keyword arguments given in the variable constructors.
Returns:
values: 1D array of values of all variables
"""
return np.array([var.get(**getkw) for var in self])
[docs]
def set(self, values: Iterable[float], **setkw) -> None:
r"""Set the values of Variables.
Args:
values: Iterable of values
Keyword Args:
check_bounds: If :py:obj:`True`, raise a ValueError any value is out
of bounds
\*\*setkw: Keyword arguments to be passed to the *setfun* function.
They augment the keyword arguments given in the variable constructors.
"""
for var, val in zip(self, values, strict=False):
var.set(val, **setkw)
[docs]
def increment(self, increment: Iterable[float], **setkw) -> None:
r"""Increment the values of Variables.
Args:
increment: Iterable of values
Keyword Args:
\*\*setkw: Keyword arguments to be passed to the *setfun* function.
They augment the keyword arguments given in the variable constructors.
"""
for var, incr in zip(self, increment, strict=False):
var.increment(incr, **setkw)
[docs]
def reset(self, **setkw) -> None:
r"""Reset to all variables their initial value and clear their history buffer.
Keyword Args:
\*\*setkw: Keyword arguments to be passed to the *setfun* function.
They augment the keyword arguments given in the variable constructors.
"""
for var in self:
var.reset(**setkw)
# noinspection PyProtectedMember
[docs]
def status(self, **kwargs) -> str:
"""String description of the variables."""
values = "\n".join(var._line(**kwargs) for var in self)
return "\n".join((VariableBase._header(), values))
[docs]
@contextmanager
def restore(self, initial: bool = False, **setkw):
r"""Context manager that saves and restore the variable list.
The value of the :py:class:`VariableList` is initially saved, and then restored
when exiting the context.
Keyword Args:
initial: If :py:obj:`True`, clear the history and set the variable
initial value
\*\*setkw: Keyword arguments to be passed to the *setfun* function.
They augment the keyword arguments given in the variable constructors.
"""
# print("Saving variables")
v0 = self.get(initial=initial, **setkw)
try:
yield
finally:
# print("Restoring variables")
self.set(v0, **setkw)
def __str__(self) -> str:
return self.status()
@property
def deltas(self) -> npt.NDArray[Number]:
"""delta values of the variables."""
return np.array([var.delta for var in self])
@deltas.setter
def deltas(self, value: Number | Sequence[Number]) -> None:
deltas = np.broadcast_to(value, len(self))
for var, delta in zip(self, deltas, strict=False):
var.delta = delta