# Source code for at.lattice.elements.element_object

"""Base :py:class:`.Element` object."""

from __future__ import annotations

__all__ = ["Element", "ReferencePoint", "transform_attr", "transform_options"]

import re
from collections.abc import Callable, Generator
from contextlib import suppress
from enum import Enum
from copy import copy, deepcopy
from typing import Any, ClassVar

import numpy as np

from .conversions import _array, _array66, _int, _float
from ..parambase import ParamDef, _nop
from ..parameters import _ACCEPTED, Param, ParamArray


class ReferencePoint(Enum):
    """Origin used by the geometric transformations of an element."""

    #: Origin at the centre of the element.
    CENTRE = 0
    #: Origin at the entrance of the element.
    ENTRANCE = 1
class _TransFormOptions:
    """Container for the module-wide transformation defaults."""

    # Reference point assumed when a file does not specify one
    referencepoint = ReferencePoint.CENTRE


# Module-wide instance holding the transformation defaults
transform_options = _TransFormOptions()

# Keywords removed from a file dictionary by Element.from_file and passed to
# the element transformation
transform_attr = "dx dy dz pitch yaw tilt reference".split()


def _no_encoder(v):
    """Identity encoder: default data converter of Element.to_file.

    Returns *v* unchanged.
    """
    return v
class Element:
    """Base class for AT elements.

    Values assigned to attributes are converted by the matching function in
    ``_conversions`` (if any). Parameter objects (:py:class:`.ParamDef`,
    :py:class:`.ParamArray`) are stored in the private ``_parameters``
    dictionary instead of the instance ``__dict__``. Reading such an
    attribute returns the current value of the parameter.
    """

    # Attributes passed positionally to the constructor, in order
    # (used by the "definition" property and by from_file)
    _BUILD_ATTRIBUTES: ClassVar[list[str]] = ["FamName"]
    # Conversion applied to the value of each attribute on assignment
    _conversions: ClassVar[dict] = {
        "FamName": str,
        "PassMethod": str,
        "Length": _float,
        "R1": _array66,
        "R2": _array66,
        "T1": lambda v: _array(v, (6,)),
        "T2": lambda v: _array(v, (6,)),
        "RApertures": lambda v: _array(v, (4,)),
        "EApertures": lambda v: _array(v, (2,)),
        "KickAngle": lambda v: _array(v, (2,)),
        "PolynomB": _array,
        "PolynomA": _array,
        "BendingAngle": _float,
        "MaxOrder": _int,
        "NumIntSteps": lambda v: _int(v, vmin=0),
        "Energy": _float,
    }
    # Optional class name written by to_file (defaults to the Python class name)
    _file_classname: ClassVar[str]
    # Entrance attributes and their exit counterparts, paired by position
    # (used by swap_faces)
    _entrance_fields: ClassVar[list[str]] = ["T1", "R1"]
    _exit_fields: ClassVar[list[str]] = ["T2", "R2"]
    # Attributes which swap_faces leaves untouched
    _no_swap = _entrance_fields + _exit_fields

    # "__dict__" holds the ordinary attributes, "_parameters" the parameters
    __slots__ = ["__dict__", "_parameters"]

    # Attributes omitted by to_file.
    # For the moment keep empty for Matlab compatibility
    # _drop_attr = ["T1", "T2", "R1", "R2"]
    _drop_attr: ClassVar[list[str]] = []
    # Renaming of attributes performed by to_file
    _convert_attr: ClassVar[dict] = {
        "_dx": "dx",
        "_dy": "dy",
        "_dz": "dz",
        "_pitch": "pitch",
        "_yaw": "yaw",
        "_tilt": "tilt",
        "_tilt_frame": "tilt_frame",
        "_referencepoint": "reference",
    }

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # _parameters must be created before any other attribute is set:
        # __setattr__ and __getattr__ both rely on it
        obj._parameters = {}
        return obj

    def __init__(self, family_name: str, **kwargs):
        """
        Parameters:
            family_name:    Name of the element

        All keywords will be set as attributes of the element. ``Length``
        defaults to 0.0 and ``PassMethod`` to ``"IdentityPass"``.
        """
        self.FamName = family_name
        self.Length = kwargs.pop("Length", 0.0)
        self.PassMethod = kwargs.pop("PassMethod", "IdentityPass")
        self.update(kwargs)

    def __setattr__(self, attrname: str, value: Any) -> None:
        """Override __setattr__ to handle parameter conversions.

        This method applies the appropriate conversion function to the value
        before setting it as an attribute. Parameters are stored in
        ``_parameters`` and plain values in ``__dict__``. An attribute is
        never present in both places.
        """
        # Get the conversion function for this attribute or use _nop (no operation)
        conversion = self._conversions.get(attrname, _nop)
        try:
            # If the value is a parameter, set its conversion function
            if isinstance(value, _ACCEPTED):
                value.set_conversion(conversion)
            # Otherwise, apply the conversion to the value
            elif not isinstance(value, ParamArray):
                value = conversion(value)
        except Exception as exc:
            # Conversion failed: prefix the message with the element identifier
            exc.args = (f"{self._ident(attrname)}: {exc}",)
            raise
        else:
            # Conversion succeeded
            if isinstance(value, (ParamDef, ParamArray)):
                # Store the parameter and remove the attribute
                self._parameters[attrname] = value
                with suppress(AttributeError):
                    # Don't care if the attribute does not exist
                    object.__delattr__(self, attrname)
            else:
                # Store the attribute and remove the parameter
                object.__setattr__(self, attrname, value)
                with suppress(KeyError):
                    # Don't care if the parameter does not exist
                    del self._parameters[attrname]

    def __getattr__(self, attrname: str) -> Any:
        """Override __getattr__ to handle parameter values.

        Called only when the normal lookup fails. Returns the value of a
        parameter instead of the parameter object itself.
        """
        try:
            return self._parameters[attrname].fast_value()
        except KeyError as exc:
            cl = self.__class__.__name__
            # object.__getattribute__ avoids recursing into __getattr__
            el = object.__getattribute__(self, "FamName")
            msg = f"{cl}({el!r}) has no attribute {attrname!r}"
            raise AttributeError(msg) from exc

    def __delattr__(self, attrname: str) -> None:
        """Override __delattr__ to handle parameter deletions."""
        try:
            object.__delattr__(self, attrname)
        except AttributeError:
            try:
                del self._parameters[attrname]
            except KeyError as exc:
                cl = self.__class__.__name__
                el = object.__getattribute__(self, "FamName")
                msg = f"{cl}({el!r}) has no attribute {attrname!r}"
                raise AttributeError(msg) from exc

    def __str__(self):
        return "\n".join(
            [self.__class__.__name__ + ":"]
            + [f"{k:>14}: {v!s}" for k, v in self.items(freeze=False)]
        )

    def __repr__(self):
        clsname, args, kwargs = self.definition
        keywords = [f"{arg!r}" for arg in args]
        keywords += [f"{k}={v!r}" for k, v in kwargs.items()]
        args = re.sub(r"\n\s*", " ", ", ".join(keywords))
        return f"{clsname}({args})"

    def __getstate__(self):
        # For pickling and copying Elements: (__dict__, slot state), with a
        # copy of the parameters so that copies do not share the dictionary
        return self.__dict__, {"_parameters": self._parameters.copy()}

    def _ident(self, attrname: str | None = None, index: int | None = None):
        """Return an element's identifier for error messages."""
        if attrname is None:
            return f"{self.__class__.__name__}({self.FamName!r})"
        elif index is None:
            return f"{self.__class__.__name__}({self.FamName!r}).{attrname}"
        else:
            return f"{self.__class__.__name__}({self.FamName!r}).{attrname}[{index}]"

    @classmethod
    def subclasses(cls) -> Generator[type[Element], None, None]:
        """Yield this class and all its subclasses, recursively.

        Subclasses are yielded before their parent, and the class itself comes
        last. Some classes may appear several times because of diamond-shaped
        inheritance.
        """
        for subclass in cls.__subclasses__():
            yield from subclass.subclasses()
        yield cls

    def keys(self):
        """Return a set of all attribute names, including parameters."""
        v = set(vars(self).keys())
        v.update(self._parameters.keys())
        return v

    def to_dict(self, freeze: bool = True):
        """Return a copy of the element attributes.

        Parameters:
            freeze: If :py:obj:`True`, parameters are replaced by their current
              value. Otherwise the parameter objects themselves are returned.
        """
        if freeze:
            return {k: getattr(self, k) for k in self.keys()}
        else:
            v = vars(self).copy()
            v.update(self._parameters)
            return v

    def get_parameter(self, attrname: str, index: int | None = None) -> Any:
        """Extract a parameter of an element.

        Unlike :py:func:`getattr`, :py:func:`get_parameter` returns the
        parameter itself instead of its value. If the item is not a parameter,
        both functions are equivalent: the value is returned.

        Args:
            attrname:   Attribute name
            index:      Index in an array attribute. If :py:obj:`None`, the
              whole attribute is returned

        Returns:
            The parameter object or attribute value.

        Raises:
            AttributeError: If the attribute does not exist
            IndexError: If *index* is out of bounds
        """
        try:
            attr = self.__dict__[attrname]
        except KeyError:
            try:
                attr = self._parameters[attrname]
            except KeyError:
                msg = f"{self._ident()} has no attribute {attrname!r}"
                raise AttributeError(msg) from None

        if index is not None:
            try:
                attr = attr[index]
            except IndexError as exc:
                msg = f"{self._ident(attrname)}: {exc}"
                raise IndexError(msg) from None

        return attr

    def equals(self, other) -> bool:
        """Whether an element is equivalent to another.

        The comparison is based on :py:func:`repr`. This implementation was
        found to be too slow for the generic __eq__ method when comparing
        lattices.
        """
        return repr(self) == repr(other)

    # noinspection PyUnresolvedReferences
    def divide(self, frac) -> list[Element]:
        """Split the element in len(frac) pieces whose length is frac[i]*self.Length.

        Parameters:
            frac:           length of each slice expressed as a fraction of the
              initial length. ``sum(frac)`` may differ from 1.

        Returns:
            elem_list:  a list of elements equivalent to the original.

        Example:

            >>> Drift("dr", 0.5).divide([0.2, 0.6, 0.2])
            [Drift('dr', 0.1), Drift('dr', 0.3), Drift('dr', 0.1)]
        """
        # By default, the element is indivisible
        return [self]

    def swap_faces(self, copy=False):
        """Swap the faces of an element, alignment errors are ignored.

        Each attribute of ``_entrance_fields`` is moved to the matching entry
        of ``_exit_fields`` and vice versa, except the attributes listed in
        ``_no_swap``, which are left in place.

        Parameters:
            copy:   If :py:obj:`True`, work on a shallow copy and return it.
              Otherwise modify the element in place and return :py:obj:`None`.
        """

        def swapattr(element, attro, attri):
            val = element.get_parameter(attri)  # get the parameter itself
            delattr(element, attri)
            return attro, val

        el = self.copy() if copy else self
        # Remove and swap entrance and exit attributes
        attrs = el.keys()
        fin = dict(
            swapattr(el, kout, kin)
            for kin, kout in zip(el._entrance_fields, el._exit_fields, strict=True)
            if kin in attrs and kin not in el._no_swap
        )
        fout = dict(
            swapattr(el, kin, kout)
            for kin, kout in zip(el._entrance_fields, el._exit_fields, strict=True)
            if kout in attrs and kout not in el._no_swap
        )
        # Apply swapped entrance and exit attributes
        for key, value in fin.items():
            setattr(el, key, value)
        for key, value in fout.items():
            setattr(el, key, value)
        return el if copy else None

    def update(self, *args, **kwargs):
        """
        update(**kwargs)
        update(mapping, **kwargs)
        update(iterable, **kwargs)
        Update the element attributes with the given arguments.
        """
        attrs = dict(*args, **kwargs)
        for key, value in attrs.items():
            setattr(self, key, value)

    def copy(self) -> Element:
        """Return a shallow copy of the element."""
        return copy(self)

    def deepcopy(self) -> Element:
        """Return a deep copy of the element."""
        return deepcopy(self)

    @property
    def definition(self) -> tuple[str, tuple, dict]:
        """tuple (class_name, args, kwargs) defining the element.

        *args* are the values of ``_BUILD_ATTRIBUTES``. *kwargs* are the
        attributes whose value differs from that of an element built from
        *args* alone.
        """
        attrs = dict(self.items())
        arguments = tuple(
            attrs.pop(k, getattr(self, k)) for k in self._BUILD_ATTRIBUTES
        )
        defelem = self.__class__(*arguments)
        keywords = {
            k: v
            for k, v in attrs.items()
            if not np.array_equal(v, getattr(defelem, k, None))
        }
        return self.__class__.__name__, arguments, keywords

    def items(self, freeze: bool = True) -> Generator[tuple[str, Any], None, None]:
        """Iterate through the data members.

        ``FamName``, ``Length`` and ``PassMethod`` come first, followed by the
        other attributes sorted by name.

        Parameters:
            freeze: If :py:obj:`True`, parameters are replaced by their value
        """
        v = self.to_dict(freeze=freeze)
        for k in ["FamName", "Length", "PassMethod"]:
            yield k, v.pop(k)
        for k, val in sorted(v.items()):
            yield k, val

    def is_compatible(self, other: Element) -> bool:
        """Checks if another :py:class:`Element` can be merged.

        The base implementation always returns :py:obj:`False`.
        """
        return False

    def merge(self, other) -> None:
        """Merge another element.

        The base implementation only checks compatibility.

        Raises:
            TypeError: If *other* is not compatible with this element
        """
        if not self.is_compatible(other):
            badname = getattr(other, "FamName", type(other))
            msg = f"Cannot merge {self.FamName} and {badname}"
            raise TypeError(msg)

    # noinspection PyMethodMayBeStatic
    def _get_longt_motion(self):
        return False

    # noinspection PyMethodMayBeStatic
    def _get_collective(self):
        return False

    def to_file(self, encoder: Callable[[Any], Any] = _no_encoder) -> dict:
        """Convert a :py:class:`.Element` to a :py:class:`dict`.

        Attributes listed in ``_drop_attr`` are omitted, and attributes are
        renamed according to ``_convert_attr``.

        Parameters:
            encoder:        data converter applied to each attribute value

        Returns:
            dct (dict):     Dictionary of :py:class:`.Element` attributes,
            with the class name stored under ``"Class"``
        """
        dct = {
            self._convert_attr.get(k, k): encoder(v)
            for k, v in self.items()
            if k not in self._drop_attr
        }
        dct["Class"] = getattr(self, "_file_classname", self.__class__.__name__)
        return dct

    @classmethod
    def from_file(cls, elem_dict: dict[str, Any]) -> Element:
        """Generate an Element from a dictionary of attributes.

        Parameters:
            elem_dict:      Dictionary of :py:class:`.Element` attributes.
              It is not modified.

        Returns:
            elem(Element):  :py:class:`.Element`

        If any transformation attribute is present, the element is
        transformed. The reference point defaults to
        ``transform_options.referencepoint`` when ``"reference"`` is missing.
        """
        # Work on a copy so that the caller's dictionary is left untouched
        elem_dict = dict(elem_dict)
        # Separate the positional arguments
        elem_args = [elem_dict.pop(attr, None) for attr in cls._BUILD_ATTRIBUTES]
        # Remove the transformation attributes
        trs = {attr: elem_dict.pop(attr, None) for attr in transform_attr}
        # Create the element
        element = cls(*(arg for arg in elem_args if arg is not None), **elem_dict)
        # Transform the element if necessary
        if not all(v is None for v in trs.values()):
            # "reference" is always a key of trs (None when absent from the
            # file), so a pop() default would never be used: fall back
            # explicitly to avoid ReferencePoint(None) raising ValueError
            refval = trs.pop("reference")
            if refval is None:
                refval = transform_options.referencepoint.value
            element.transform(reference=ReferencePoint(refval), **trs)
        return element

    @property
    def longt_motion(self) -> bool:
        """:py:obj:`True` if the element affects the longitudinal motion."""
        return self._get_longt_motion()

    @property
    def is_collective(self) -> bool:
        """:py:obj:`True` if the element involves collective effects."""
        return self._get_collective()

    def set_parameter(
        self, attrname: str, value: Any, index: int | None = None
    ) -> None:
        """Set an element's parameter.

        This allows setting a parameter into an attribute or an item of an
        array attribute.

        Args:
            attrname:   Attribute name
            value:      Parameter or value to be set
            index:      Index into an array attribute. If *value* is a
              parameter, the array attribute is converted to a
              :py:class:`.ParamArray`.

        Raises:
            IndexError: If the provided index is out of bounds for the
              array attribute
            AttributeError: If the attribute doesn't exist
        """

        def set_array_item(arr: np.ndarray, idx: int, val: Any) -> None:
            """Set an item in an array, prefixing errors with the element identifier."""
            try:
                arr[idx] = val
            except IndexError as exc:
                exc.args = (f"{self._ident(attrname)}: {exc}",)
                raise

        # Set the entire attribute
        if index is None:
            setattr(self, attrname, value)
        # Set a specific index in an array attribute
        else:
            array = self.get_parameter(attrname)
            if not isinstance(array, ParamArray) and isinstance(value, ParamDef):
                # Convert the array to a ParamArray if it's not already one
                array = ParamArray(array, shape=array.shape, dtype=array.dtype)
                set_array_item(array, index, value)
                setattr(self, attrname, array)
            else:
                # Modify the existing array in place
                set_array_item(array, index, value)

    def is_parameterised(
        self, attrname: str | None = None, index: int | None = None
    ) -> bool:
        """Check for the parameterisation of an element.

        Args:
            attrname:   Attribute name. If :py:obj:`None`, checks if any
              attribute is parameterised
            index:      Index in an array attribute. If :py:obj:`None`, tests
              the whole attribute

        Returns:
            True if the attribute, or array item is parameterised, False otherwise
        """
        # Check if any attribute is parameterised
        # Works for AT and MADX parameters
        if attrname is None:
            if len(self._parameters) > 0:
                return True
            # Check for MADX parameters
            return any(self.is_parameterised(attribute) for attribute in self.__dict__)

        # Get the attribute or specific index
        attribute = self.get_parameter(attrname, index=index)
        return isinstance(attribute, (ParamDef, ParamArray))

    def parameterise(
        self, attrname: str, index: int | None = None, name: str = ""
    ) -> _ACCEPTED:
        """Convert attribute to parameter preserving value.

        The value of the attribute is kept unchanged. If the attribute is
        already parameterised, the existing parameter is returned.

        Args:
            attrname:   Attribute name
            index:      Index in an array. If :py:obj:`None`, the whole
              attribute is parameterised
            name:       Name of the created parameter

        Returns:
            A :py:class:`.ParamArray` for an array attribute,
            a :py:class:`.Param` for a scalar attribute or an item in an
            array attribute

        Raises:
            TypeError: If the attribute value is not a valid parameter type (Number)
            IndexError: If the index is out of bounds for an array attribute
            AttributeError: If the attribute does not exist
        """
        # Get the current value of the attribute or array element
        current_value = self.get_parameter(attrname, index=index)

        # If it's already a parameter, return it
        if isinstance(current_value, _ACCEPTED):
            return current_value

        # Create a new parameter with the current value
        try:
            param = Param(current_value, name=name)
        except TypeError as exc:
            exc.args = (f"Cannot parameterise {self._ident(attrname)}: {exc}",)
            raise

        # Set the parameter in the element
        self.set_parameter(attrname, param, index=index)
        return param

    def unparameterise(
        self, attrname: str | None = None, index: int | None = None
    ) -> None:
        """Replace parameters with their current values.

        This function replaces parameters with their current values,
        effectively "freezing" them. This is useful when you want to convert
        a parameterised element back to a regular element with fixed values.

        Args:
            attrname:   Attribute name. If :py:obj:`None`, freezes all attributes
            index:      Index in an array. If :py:obj:`None`, freezes the
              whole attribute

        Attributes which are not parameters are silently ignored.
        """
        # Works for AT and MADX parameters
        if attrname is None:  # freeze all the attributes
            # make a copy of the parameters dict to avoid modifications during iteration
            for name, attr in list(self._parameters.items()):
                setattr(self, name, attr.value)
        else:
            attr = self.get_parameter(attrname)
            if not isinstance(attr, (ParamDef, ParamArray)):
                # silently ignore non-parameter attributes
                return
            if index is None:  # freeze a scalar attribute
                setattr(self, attrname, attr.value)
            else:  # freeze an item in an array attribute
                item = attr[index]
                if isinstance(item, ParamDef):
                    attr[index] = item.value
                if not any(isinstance(item, ParamDef) for item in attr.flat):
                    # freeze the whole array attribute if no parameter left
                    setattr(self, attrname, attr.value)