"""Base :py:class:`.Element` object."""
from __future__ import annotations
# Public names of this module (what ``from <module> import *`` exports).
__all__ = ["Element", "ReferencePoint", "transform_attr", "transform_options"]
import re
from collections.abc import Callable, Generator
from enum import Enum
from copy import copy, deepcopy
from typing import Any, ClassVar
import numpy as np
from .conversions import _array, _array66, _int, _float
[docs]
class ReferencePoint(Enum):
    """Reference point of the geometric transformations.

    Selects where the origin of the coordinate system used by the
    transformations is placed.
    """

    #: Origin at the centre of the element.
    CENTRE = 0
    #: Origin at the entrance of the element.
    ENTRANCE = 1
class _TransFormOptions:
    """Module-wide options of the geometric transformations."""

    #: Reference point used by ``Element.from_file`` when the file
    #: does not provide one.
    referencepoint: ReferencePoint = ReferencePoint.CENTRE


#: Shared instance holding the current transformation options.
transform_options = _TransFormOptions()

#: Transformation keys that ``Element.from_file`` extracts from the attribute
#: dictionary and passes to the element's ``transform`` method.
transform_attr = [
    "dx",
    "dy",
    "dz",
    "pitch",
    "yaw",
    "tilt",
    "reference",
]
def _nop(value):
return value
def _no_encoder(v):
"""type encoding for .mat files."""
return v
[docs]
class Element:
    """Base class for AT elements."""

    # Attributes passed positionally to the constructor (used by
    # ``definition`` and ``from_file``)
    _BUILD_ATTRIBUTES: ClassVar[list[str]] = ["FamName"]
    # Conversion applied by __setattr__ to the value of each listed attribute
    _conversions: ClassVar[dict] = {
        "FamName": str,
        "PassMethod": str,
        "Length": _float,
        "R1": _array66,
        "R2": _array66,
        "T1": lambda v: _array(v, (6,)),
        "T2": lambda v: _array(v, (6,)),
        "RApertures": lambda v: _array(v, (4,)),
        "EApertures": lambda v: _array(v, (2,)),
        "KickAngle": lambda v: _array(v, (2,)),
        "PolynomB": _array,
        "PolynomA": _array,
        "BendingAngle": _float,
        "MaxOrder": _int,
        "NumIntSteps": lambda v: _int(v, vmin=0),
        "Energy": _float,
    }

    # Class name written by to_file; when not set on a class, to_file falls
    # back to the Python class name
    _file_classname: ClassVar[str]
    # Entrance/exit face attributes, paired index by index by swap_faces
    _entrance_fields: ClassVar[list[str]] = ["T1", "R1"]
    _exit_fields: ClassVar[list[str]] = ["T2", "R2"]
    # Face attributes that swap_faces never exchanges
    _no_swap: ClassVar[list[str]] = _entrance_fields + _exit_fields
    # Attributes omitted by to_file.
    # For the moment keep empty for Matlab compatibility
    # _drop_attr = ["T1", "T2", "R1", "R2"]
    _drop_attr: ClassVar[list[str]] = []
    # Renaming of attribute names when written by to_file
    _convert_attr: ClassVar[dict] = {
        "_dx": "dx",
        "_dy": "dy",
        "_dz": "dz",
        "_pitch": "pitch",
        "_yaw": "yaw",
        "_tilt": "tilt",
        "_tilt_frame": "tilt_frame",
        "_referencepoint": "reference",
    }

    def __init__(self, family_name: str, **kwargs):
        """
        Parameters:
            family_name: Name of the element.

        All keywords will be set as attributes of the element.
        ``Length`` defaults to 0.0 and ``PassMethod`` to ``"IdentityPass"``.
        """
        self.FamName = family_name
        self.Length = kwargs.pop("Length", 0.0)
        self.PassMethod = kwargs.pop("PassMethod", "IdentityPass")
        self.update(kwargs)

    def __setattr__(self, attrname: str, value: Any) -> None:
        """Override __setattr__ to handle parameter conversions.

        This method applies the appropriate conversion function to the value
        before setting it as an attribute. Attributes without a registered
        conversion are stored unchanged.

        Raises:
            Exception: whatever the conversion raises, with the element and
              attribute identification prepended to its message.
        """
        try:
            value = self._conversions.get(attrname, _nop)(value)
        except Exception as exc:
            # Conversion failed: prefix the message with the element identity
            exc.args = (f"{self._ident(attrname)}: {exc}",)
            raise
        # Conversion succeeded
        super().__setattr__(attrname, value)

    def __str__(self):
        return "\n".join(
            [self.__class__.__name__ + ":"]
            + [f"{k:>14}: {v!s}" for k, v in self.items()]
        )

    def __repr__(self):
        clsname, args, kwargs = self.definition
        keywords = [f"{arg!r}" for arg in args]
        keywords += [f"{k}={v!r}" for k, v in kwargs.items()]
        # Collapse multi-line reprs (e.g. numpy arrays) onto a single line
        args = re.sub(r"\n\s*", " ", ", ".join(keywords))
        return f"{clsname}({args})"

    def _ident(self, attrname: str | None = None, index: int | None = None):
        """Return an element's identifier for error messages."""
        # FamName may not exist yet, e.g. when the failing assignment is the
        # FamName assignment in __init__: do not raise a second error here
        famname = getattr(self, "FamName", None)
        name = f"{self.__class__.__name__}({famname!r})"
        if attrname is None:
            return name
        if index is None:
            return f"{name}.{attrname}"
        return f"{name}.{attrname}[{index}]"

    @classmethod
    def subclasses(cls) -> Generator[type[Element], None, None]:
        """Yields all the class subclasses.

        Subclasses are yielded recursively (depth-first), followed by the
        class itself. Some classes may appear several times because of
        diamond-shape inheritance.
        """
        for subclass in cls.__subclasses__():
            yield from subclass.subclasses()
        yield cls

    def to_dict(self):
        """Return a shallow copy of the instance attribute dictionary."""
        return vars(self).copy()

    def equals(self, other) -> bool:
        """Whether an element is equivalent to another.

        Two elements are equivalent when their :py:func:`repr` are equal.
        This implementation was found to be too slow for the generic
        __eq__ method when comparing lattices.
        """
        return repr(self) == repr(other)

    def divide(self, frac) -> list[Element]:
        # noinspection PyUnresolvedReferences
        """Split the element in len(frac) pieces whose length is frac[i]*self.Length.

        Parameters:
            frac: length of each slice expressed as a fraction of the
              initial length. ``sum(frac)`` may differ from 1.

        Returns:
            elem_list: a list of elements equivalent to the original.

        Example:
            >>> Drift("dr", 0.5).divide([0.2, 0.6, 0.2])
            [Drift('dr', 0.1), Drift('dr', 0.3), Drift('dr', 0.1)]
        """
        # By default, the element is indivisible
        return [self]

    def swap_faces(self, copy=False):
        """Swap the faces of an element, alignment errors are ignored.

        Each attribute of ``_entrance_fields`` is exchanged with the attribute
        at the same position in ``_exit_fields``. Attributes listed in
        ``_no_swap`` (``T1``, ``R1``, ``T2``, ``R2`` for this class) are left
        in place.

        Parameters:
            copy: If :py:obj:`True`, work on a shallow copy and return it.
              Otherwise modify the element in place and return :py:obj:`None`.
        """

        def swapattr(element, attro, attri):
            # Remove attri from element, return its value keyed by attro
            val = getattr(element, attri)
            delattr(element, attri)
            return attro, val

        el = self.copy() if copy else self
        # Remove and swap entrance and exit attributes
        fin = dict(
            swapattr(el, kout, kin)
            for kin, kout in zip(el._entrance_fields, el._exit_fields, strict=True)
            if kin in vars(el) and kin not in el._no_swap
        )
        fout = dict(
            swapattr(el, kin, kout)
            for kin, kout in zip(el._entrance_fields, el._exit_fields, strict=True)
            if kout in vars(el) and kout not in el._no_swap
        )
        # Apply swapped entrance and exit attributes
        for key, value in fin.items():
            setattr(el, key, value)
        for key, value in fout.items():
            setattr(el, key, value)
        return el if copy else None

    def update(self, *args, **kwargs):
        """
        update(**kwargs)
        update(mapping, **kwargs)
        update(iterable, **kwargs)

        Update the element attributes with the given arguments.
        Arguments are interpreted as by the :py:class:`dict` constructor, and
        each value goes through the attribute conversions.
        """
        attrs = dict(*args, **kwargs)
        for key, value in attrs.items():
            setattr(self, key, value)

    def copy(self) -> Element:
        """Return a shallow copy of the element."""
        return copy(self)

    def deepcopy(self) -> Element:
        """Return a deep copy of the element."""
        return deepcopy(self)

    @property
    def definition(self) -> tuple[str, tuple, dict]:
        """tuple (class_name, args, kwargs) defining the element.

        ``args`` holds the values of the ``_BUILD_ATTRIBUTES``. ``kwargs``
        holds the attributes whose value differs from that of an element
        built from ``args`` alone.
        """
        attrs = dict(self.items())
        # Use getattr only when the attribute is not an instance attribute
        arguments = tuple(
            attrs.pop(k) if k in attrs else getattr(self, k)
            for k in self._BUILD_ATTRIBUTES
        )
        # Reference element: keep only the attributes that differ from it
        defelem = self.__class__(*arguments)
        keywords = {
            k: v
            for k, v in attrs.items()
            if not np.array_equal(v, getattr(defelem, k, None))
        }
        return self.__class__.__name__, arguments, keywords

    def items(self) -> Generator[tuple[str, Any], None, None]:
        """Iterates through the data members.

        ``FamName``, ``Length`` and ``PassMethod`` come first, followed by
        the other instance attributes sorted by name.
        """
        v = self.to_dict()
        for k in ["FamName", "Length", "PassMethod"]:
            yield k, v.pop(k)
        yield from sorted(v.items())

    def is_compatible(self, other: Element) -> bool:
        """Checks if another :py:class:`Element` can be merged.

        The base class is never compatible with any element.
        """
        return False

    def merge(self, other) -> None:
        """Merge another element.

        Raises:
            TypeError: if *other* is not compatible (always the case for the
              base class, whose :py:meth:`is_compatible` returns False).
        """
        if not self.is_compatible(other):
            badname = getattr(other, "FamName", type(other))
            msg = f"Cannot merge {self.FamName} and {badname}"
            raise TypeError(msg)

    # noinspection PyMethodMayBeStatic
    def _get_longt_motion(self):
        # Backs the longt_motion property; False for the base class
        return False

    # noinspection PyMethodMayBeStatic
    def _get_collective(self):
        # Backs the is_collective property; False for the base class
        return False

    def to_file(self, encoder: Callable[[Any], Any] = _no_encoder) -> dict:
        """Convert a :py:class:`.Element` to a :py:class:`dict`.

        Attributes listed in ``_drop_attr`` are skipped, attribute names are
        renamed according to ``_convert_attr`` and a ``"Class"`` key is added.

        Parameters:
            encoder: data converter applied to each attribute value

        Returns:
            dct (dict): Dictionary of :py:class:`.Element` attributes
        """
        dct = {
            self._convert_attr.get(k, k): encoder(v)
            for k, v in self.items()
            if k not in self._drop_attr
        }
        dct["Class"] = getattr(self, "_file_classname", self.__class__.__name__)
        return dct

    @classmethod
    def from_file(cls, elem_dict: dict[str, Any]) -> Element:
        """Generate an Element from a dictionary of attributes.

        Parameters:
            elem_dict: Dictionary of :py:class:`.Element` attributes. It is
              not modified.

        Returns:
            elem(Element): :py:class:`.Element`
        """
        # Work on a copy: the caller's dictionary is left untouched
        attrs = dict(elem_dict)
        # Separate the positional arguments
        elem_args = [attrs.pop(attr, None) for attr in cls._BUILD_ATTRIBUTES]
        # Remove the transformation attributes
        trs = {attr: attrs.pop(attr, None) for attr in transform_attr}
        # Create the element
        element = cls(*(arg for arg in elem_args if arg is not None), **attrs)
        # Transform the element if necessary
        if any(v is not None for v in trs.values()):
            # "reference" is always a key of trs (possibly None), so the
            # module-wide default must be applied explicitly
            refval = trs.pop("reference")
            if refval is None:
                refval = transform_options.referencepoint.value
            # NOTE(review): transform() is not defined in this class body;
            # presumably attached to Element elsewhere - confirm
            element.transform(reference=ReferencePoint(refval), **trs)
        return element

    @property
    def longt_motion(self) -> bool:
        """:py:obj:`True` if the element affects the longitudinal motion."""
        return self._get_longt_motion()

    @property
    def is_collective(self) -> bool:
        """:py:obj:`True` if the element involves collective effects."""
        return self._get_collective()