# Source code for at.load.utils

"""
Conversion utilities for creating pyat elements
"""

from __future__ import annotations

# Public API of this module: names exported by "from at.load.utils import *"
__all__ = [
    "element_classes",
    "element_from_dict",
    "element_to_dict",
    "find_class",
    "keep_elements",
    "keep_attributes",
    "split_ignoring_parentheses",
    "RingParam",
    "protect",
    "restore",
]

import collections
import os
import re
import sysconfig
from typing import Any
from warnings import warn
from collections.abc import Callable, Generator

import numpy as np

from at import integrators
from at.lattice import AtWarning
from at.lattice import elements as elt
from at.lattice import Lattice, Particle, Element, Marker
from at.lattice import idtable_element

# File-name suffix of compiled extension modules for this interpreter, appended
# to a PassMethod name to locate its compiled library in at.integrators
_ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
# Default token substituted for protected expressions by protect()/restore()
_plh = "placeholder"


def _no_encoder(v):
    """type encoding for .mat files"""
    return v


def _particle(value) -> Particle:
    """Convert *value* into a :py:class:`.Particle`.

    Args:
        value: an existing :py:class:`.Particle` (returned as is), a
          :py:class:`dict` of keyword arguments for the Particle constructor,
          or any single argument accepted by the Particle constructor.

    Returns:
        particle: the corresponding :py:class:`.Particle`
    """
    # Already a Particle (e.g. lattice built in python, saved with save_mat)
    if isinstance(value, Particle):
        return value
    # Anything not a dict is given as the single constructor argument
    if not isinstance(value, dict):
        return Particle(value)
    # Dict of attributes (e.g. structure read by load_mat)
    return Particle(**value)


def _warn(index: int | None, message: str, elem_dict: dict) -> None:
    """Emit an :py:class:`.AtWarning` concerning a specific element.

    Args:
        index: Index of the element in the lattice, or :py:obj:`None` when
          unknown (only the element name is then reported).
        message: Description of the problem.
        elem_dict: Element attributes, appended to the warning text.
    """
    name = elem_dict.get("FamName", "")
    if index is None:
        location = f'"{name}":\n'
    else:
        location = f'{index} ("{name}"):\n'
    warning = f"In element {location}{message}\n{elem_dict}\n"
    # stacklevel=2: report the location of the function calling _warn
    warn(AtWarning(warning), stacklevel=2)


def element_classes() -> frozenset[type[Element]]:
    """Return :py:class:`.Element` together with all its subclasses.

    Subclasses are collected transitively (subclasses of subclasses, ...).
    Class aliases (Bend, Matrix66) are not listed under their alias names.
    """
    collected = {Element}
    to_visit = [Element]
    while to_visit:
        for child in to_visit.pop().__subclasses__():
            if child not in collected:
                collected.add(child)
                to_visit.append(child)
    return frozenset(collected)
class RingParam(elt.Element):
    """Private class for Matlab RingParam element

    :meta private:
    """

    # noinspection PyProtectedMember
    _BUILD_ATTRIBUTES = elt.Element._BUILD_ATTRIBUTES + [
        "Energy",
        "Periodicity",
    ]
    _conversions = dict(
        elt.Element._conversions,
        Energy=float,
        Periodicity=int,
        Particle=_particle,
        cell_harmnumber=float,
    )

    # noinspection PyPep8Naming
    def __init__(
        self,
        FamName: str,
        Energy: float,
        Periodicity: int,
        **kwargs,
    ):
        # A NaN Energy is not stored as an attribute
        if not np.isnan(float(Energy)):
            kwargs["Energy"] = Energy
        kwargs.setdefault("PassMethod", "IdentityPass")
        super().__init__(FamName, Periodicity=Periodicity, **kwargs)


# Alternative class names. Lookups in _CLASS_MAP use lower-cased names.
# NOTE(review): the "M66" key can never be hit by a lower-cased lookup; "m66"
# is already provided by the M66 class name itself.
_alias_map = {
    "bend": elt.Dipole,
    "rbend": elt.Dipole,
    "sbend": elt.Dipole,
    "quad": elt.Quadrupole,
    "sext": elt.Sextupole,
    "rf": elt.RFCavity,
    "bpm": elt.Monitor,
    "ap": elt.Aperture,
    "ringparam": RingParam,
    "wig": elt.Wiggler,
    "matrix66": elt.M66,
    "M66": elt.M66,
}

# Map class names to Element classes
_CLASS_MAP = {cls.__name__.lower(): cls for cls in element_classes()}
_CLASS_MAP.update(_alias_map)

# Maps passmethods to Element classes
_PASS_MAP = {
    "BendLinearPass": elt.Dipole,
    "BndMPoleSymplectic4RadPass": elt.Dipole,
    "BndMPoleSymplectic4Pass": elt.Dipole,
    "QuadLinearPass": elt.Quadrupole,
    "StrMPoleSymplectic4Pass": elt.Multipole,
    "StrMPoleSymplectic4RadPass": elt.Multipole,
    "CorrectorPass": elt.Corrector,
    "CavityPass": elt.RFCavity,
    "RFCavityPass": elt.RFCavity,
    "ThinMPolePass": elt.ThinMultipole,
    "Matrix66Pass": elt.M66,
    "AperturePass": elt.Aperture,
    "IdTablePass": idtable_element.InsertionDeviceKickMap,
    "GWigSymplecticPass": elt.Wiggler,
}

# Maps python class name to Matlab class
# Default: element_class.__name__
_mat_class = {
    "Dipole": "Bend",
    "M66": "Matrix66",
}

# Lattice attributes which must be dropped when writing a file
_drop_attrs = {
    "in_file": None,
    "use": None,
    "mat_key": None,
    "mat_file": None,  # Not used anymore...
    "m_file": None,
    "repr_file": None,
}


def _hasattrs(kwargs: dict, *attributes) -> bool:
    """Checks the presence of keys in a :py:class:`dict`

    Returns :py:obj:`True` if any of the ``attributes`` is in ``kwargs``

    Args:
        kwargs: The dictionary of keyword arguments passed to the
          Element constructor.
        attributes: A list of strings, the attribute names to be checked.

    Returns:
        found (bool): :py:obj:`True` if the element has any of the specified
          attributes.
    """
    return any(attribute in kwargs for attribute in attributes)
def keep_attributes(ring: Lattice):
    """Return the Lattice attributes which may be written to a file.

    Attributes mapped to :py:obj:`None` in ``_drop_attrs`` are filtered out.
    All other attributes are kept.

    Args:
        ring: Lattice whose ``attrs`` are filtered

    Returns:
        attrs (dict): New dictionary of the attributes to save
    """
    kept = {}
    for name, value in ring.attrs.items():
        # Names absent from _drop_attrs default to themselves, hence are kept
        if _drop_attrs.get(name, name) is None:
            continue
        kept[name] = value
    return kept
def keep_elements(ring: Lattice) -> Generator[Element, None, None]:
    """Iterate over the lattice elements, skipping the 'RingParam' Marker.

    An element is skipped when it is a :py:class:`.Marker` whose ``tag``
    attribute equals "RingParam". All other elements are yielded in order.
    """
    for elem in ring:
        is_ringparam = (
            isinstance(elem, Marker) and getattr(elem, "tag", None) == "RingParam"
        )
        if is_ringparam:
            continue
        yield elem
def _from_contents(elem: dict) -> type[Element]:
    """Deduce the element class from its contents

    The attributes of *elem* are tested in a fixed order of priority. The
    first matching rule determines the class.

    Args:
        elem: Dictionary of element attributes.

    Returns:
        element_class: The guessed :py:class:`.Element` subclass.
    """

    def low_order(key):
        """Index of the first non-zero coefficient of polynomial *key*.

        Returns -1 if the polynomial is absent or all its coefficients are 0.
        """
        if key not in elem:
            # e.g. an element defining PolynomA but not PolynomB
            return -1
        polynom = np.array(elem[key], dtype=np.float64).reshape(-1)
        nonzero = np.flatnonzero(polynom)
        return int(nonzero[0]) if nonzero.size > 0 else -1

    length = float(elem.get("Length", 0.0))
    pass_method = elem.get("PassMethod", "")
    if _hasattrs(
        elem, "FullGap", "FringeInt1", "FringeInt2", "gK", "EntranceAngle", "ExitAngle"
    ):
        return elt.Dipole
    elif _hasattrs(elem, "Voltage", "Frequency", "HarmNumber", "PhaseLag", "TimeLag"):
        return elt.RFCavity
    elif _hasattrs(elem, "Periodicity"):
        # noinspection PyProtectedMember
        return RingParam
    elif _hasattrs(elem, "Limits"):
        return elt.Aperture
    elif _hasattrs(elem, "M66"):
        return elt.M66
    elif _hasattrs(elem, "K"):
        return elt.Quadrupole
    elif _hasattrs(elem, "PolynomB", "PolynomA"):
        # Classify by the lowest non-zero order of PolynomB
        loworder = low_order("PolynomB")
        if loworder == 1:
            return elt.Quadrupole
        elif loworder == 2:
            return elt.Sextupole
        elif loworder == 3:
            return elt.Octupole
        elif pass_method.startswith("StrMPoleSymplectic4") or (length > 0):
            return elt.Multipole
        else:
            return elt.ThinMultipole
    elif _hasattrs(elem, "KickAngle"):
        return elt.Corrector
    elif length > 0.0:
        return elt.Drift
    elif _hasattrs(elem, "GCR"):
        return elt.Monitor
    elif pass_method == "IdentityPass":
        return elt.Marker
    else:
        return elt.Element
def find_class(
    elem_dict: dict, quiet: bool = False, index: int | None = None
) -> type[Element]:
    """Deduce the class of an element from its attributes

    `find_class` looks first at the "Class" field, if existing. It then tries
    to deduce the class from "FamName", from "PassMethod", and finally from
    the element contents.

    Note:
        The "Class" key, if present, is removed from *elem_dict*.

    Args:
        elem_dict: The dictionary of keyword arguments passed to the
          Element constructor.
        quiet: Suppress the warning for non-standard classes
        index: Element index in the lattice

    Returns:
        element_class: The guessed :py:class:`.Element` subclass
    """

    def check_class(clname):
        # Warn only when a Class was given but is unknown
        if clname:
            _warn(index, f"Class '{clname}' does not exist.", elem_dict)

    def check_pass(passm):
        if not passm:
            _warn(index, "No PassMethod provided.", elem_dict)
        elif not passm.endswith("Pass"):
            message = (
                f"Invalid PassMethod '{passm}': "
                "provided pass methods should end in 'Pass'."
            )
            _warn(index, message, elem_dict)

    # "Class" is popped so that it is not passed to the Element constructor
    class_name = elem_dict.pop("Class", "")  # try from class name
    cls = _CLASS_MAP.get(class_name.lower(), None)
    if cls is not None:
        return cls
    elif not quiet:
        check_class(class_name)

    elname = elem_dict.get("FamName", "")  # try from element name
    cls = _CLASS_MAP.get(elname.lower(), None)
    if cls is not None:
        return cls

    pass_method = elem_dict.get("PassMethod", "")  # try from passmethod
    cls = _PASS_MAP.get(pass_method, None)
    if cls is not None:
        return cls
    elif not quiet:
        check_pass(pass_method)

    return _from_contents(elem_dict)  # look for contents
def element_from_dict(
    elem_dict: dict,
    index: int | None = None,
    check: bool = True,
    quiet: bool = False,
) -> Element:
    """Builds an :py:class:`.Element` from a dictionary of attributes

    The class is guessed with :py:func:`find_class`. *elem_dict* is consumed:
    its "Class" key and the build attributes of the class are removed from it.

    Parameters:
        elem_dict: Dictionary of element attributes
        index: Element index, reported in warning messages
        check: Check the compatibility of class and PassMethod
        quiet: Suppress the warning for non-standard classes

    Returns:
        elem (Element): new :py:class:`.Element`
    """

    # noinspection PyShadowingNames
    def sanitise_class(index, cls, elem_dict):
        """Warn about Class / PassMethod problems which would fail later.

        Only warnings are emitted, no exception is raised. A warning is
        issued when:

        - the compiled library of the PassMethod cannot be found,
        - "IdentityPass" is used with a non-zero length,
        - *cls* is not a subclass of the class associated with the PassMethod.
        """
        passm = elem_dict.get("PassMethod")
        if passm is None:
            return
        expected_class = _PASS_MAP.get(passm)
        elem_length = float(elem_dict.get("Length", 0.0))
        lib_name = passm + _ext_suffix
        lib_path = os.path.join(integrators.__path__[0], lib_name)
        if not os.path.isfile(os.path.realpath(lib_path)):
            _warn(index, f"PassMethod {passm} is missing {lib_name}.", elem_dict)
        elif passm == "IdentityPass" and elem_length != 0.0:
            _warn(
                index,
                f"PassMethod {passm} is not compatible with length {elem_length}.",
                elem_dict,
            )
        elif expected_class is not None and not issubclass(cls, expected_class):
            _warn(
                index,
                f"PassMethod {passm} is not compatible with Class {cls.__name__}.",
                elem_dict,
            )

    cls = find_class(elem_dict, quiet=quiet, index=index)
    if check:
        sanitise_class(index, cls, elem_dict)
    # Pop every mandatory attribute from elem_dict, so that it is not passed
    # again as a keyword argument, and keep the present ones as positionals.
    # NOTE(review): a missing attribute shifts the following ones to earlier
    # positional slots.
    positional = []
    for attr in cls._BUILD_ATTRIBUTES:
        value = elem_dict.pop(attr, None)
        if value is not None:
            positional.append(value)
    return cls(*positional, **elem_dict)
def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) -> dict:
    """Convert a :py:class:`.Element` to a :py:class:`dict`

    Each attribute value is passed through *encoder*. A "Class" entry is
    added, holding the Matlab class name when it differs from the python one.

    Parameters:
        elem:       :py:class:`.Element`
        encoder:    data converter applied to each attribute value

    Returns:
        dct (dict): Dictionary of :py:class:`.Element` attributes
    """
    result = {}
    for key, value in elem.items():
        result[key] = encoder(value)
    py_name = elem.__class__.__name__
    result["Class"] = _mat_class.get(py_name, py_name)
    return result
def split_ignoring_parentheses(
    string: str,
    delimiter: str = ",",
    fence: tuple[str, str] = ("\\(", "\\)"),
    maxsplit: int = -1,
) -> list[str]:
    """Split *string* on *delimiter*, except inside fenced expressions

    Delimiters enclosed in a protected expression (parentheses by default) do
    not split the string.

    Example:
        "l=0,hom(4,0.0,0)" -> ["l=0", "hom(4,0.0,0)"]

    Args:
        string: Text to be split
        delimiter: Separator string
        fence: Opening and closing regular expressions delimiting the
          protected expressions
        maxsplit: Maximum number of splits, -1 for no limit

    Returns:
        parts: The split fields, with protected expressions restored
    """
    masked, replmatch = protect(string, fence=fence)
    fields = masked.split(delimiter, maxsplit=maxsplit)
    return restore(replmatch, *fields)
def protect(
    string: str,
    fence: tuple[str, str] = ('"', '"'),
    *,
    placeholder: str = _plh,
):
    """Replace fenced expressions by a placeholder

    Each non-overlapping match of the regular expression
    ``fence[0]``, characters other than ``fence[0]``, ``fence[1]`` is replaced
    by *placeholder*. The substituted string can then be processed (e.g.
    split) without affecting the protected expressions. Use
    :py:func:`restore` to put them back. Nested fences are not supported:
    for parentheses, only the innermost expressions are protected.

    If *placeholder* already occurs in *string*, a numbered variant
    (``placeholder0``, ``placeholder1``, ...) absent from *string* is used
    instead. This way, restoring cannot confuse original text with
    substituted tokens.

    Args:
        string: Input text
        fence: Opening and closing regular expressions delimiting the
          protected expressions
        placeholder: Replacement token

    Returns:
        substituted: The string with protected expressions replaced
        replmatch: ``(placeholder, matches)`` tuple to be given to
          :py:func:`restore`
    """
    inf, outf = fence
    pattern = f"{inf}[^{inf}]*?{outf}"
    matches = collections.deque(re.finditer(pattern, string))
    # A token already present in the input would be substituted back by
    # restore(): choose a variant which does not occur in the input.
    base = placeholder
    counter = 0
    while placeholder in string:
        placeholder = f"{base}{counter}"
        counter += 1
    # Rebuild the string from the text between the matches (linear time,
    # independent of the match contents)
    pieces = []
    last = 0
    for match in matches:
        pieces.append(string[last : match.start()])
        pieces.append(placeholder)
        last = match.end()
    pieces.append(string[last:])
    return "".join(pieces), (placeholder, matches)
def restore(replmatch, *parts):
    """Put back the expressions replaced by :py:func:`protect`

    Placeholders are replaced in order, each by the next stored match, across
    all *parts* in sequence. The stored matches are consumed.

    Args:
        replmatch: ``(placeholder, matches)`` tuple returned by
          :py:func:`protect`
        parts: Strings containing placeholders

    Returns:
        replaced_parts: List of the restored strings
    """
    placeholder, matches = replmatch

    def _fill(text):
        while placeholder in text:
            text = text.replace(placeholder, matches.popleft().group(), 1)
        return text

    restored = list(map(_fill, parts))
    # Every protected expression must have been used
    assert not matches
    return restored
# Attach the conversion functions to Element: Element.from_dict(dct) builds an
# element from a dict of attributes, elem.to_dict() converts an element to one
Element.from_dict = staticmethod(element_from_dict)
Element.to_dict = element_to_dict