Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/pals/kinds/Lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from pydantic import model_validator, Field
from typing import Annotated, List, Literal, Union

from .BeamLine import BeamLine
from .mixin import BaseElement
from ..functions import load_file_to_dict, store_dict_to_file


class Lattice(BaseElement):
"""A line of elements and/or other lines"""
Comment thread
ax3l marked this conversation as resolved.
Outdated

kind: Literal["Lattice"] = "Lattice"

branches: List[Annotated[Union[BeamLine], Field(discriminator="kind")]]

@model_validator(mode="before")
@classmethod
def unpack_json_structure(cls, data):
"""Deserialize the JSON/YAML/...-like dict for Lattice elements"""
from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure

return unpack_element_list_structure(data, "branches", "branches")

def model_dump(self, *args, **kwargs):
"""Custom model dump for Lattice to handle element list formatting"""
from pals.kinds.mixin.all_element_mixin import dump_element_list

return dump_element_list(self, "branches", *args, **kwargs)

@staticmethod
def from_file(filename: str) -> "Lattice":
"""Load a Lattice from a text file"""
pals_dict = load_file_to_dict(filename)
return Lattice(**pals_dict)

def to_file(self, filename: str):
"""Save a Lattice to a text file"""
pals_dict = self.model_dump()
store_dict_to_file(filename, pals_dict)
Comment thread
EZoni marked this conversation as resolved.
2 changes: 2 additions & 0 deletions src/pals/kinds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .ACKicker import ACKicker # noqa: F401
from .BeamBeam import BeamBeam # noqa: F401
from .BeamLine import BeamLine # noqa: F401
from .Lattice import Lattice # noqa: F401
from .BeginningEle import BeginningEle # noqa: F401
from .Converter import Converter # noqa: F401
from .CrabCavity import CrabCavity # noqa: F401
Expand Down Expand Up @@ -39,3 +40,4 @@
# Rebuild pydantic models that depend on other classes
UnionEle.model_rebuild()
BeamLine.model_rebuild()
Lattice.model_rebuild()
1 change: 1 addition & 0 deletions src/pals/kinds/all_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
def get_all_element_types(extra_types: tuple = None):
"""Return a tuple of all element types that can be used in BeamLine or UnionEle."""
element_types = (
"Lattice", # Forward reference to handle circular import
"BeamLine", # Forward reference to handle circular import
"UnionEle", # Forward reference to handle circular import
ACKicker,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,29 @@ def test_UnionEle():
assert len(element_with_children.elements) == 2
assert element_with_children.elements[0].name == "m1"
assert element_with_children.elements[1].name == "d1"


def test_Lattice():
# Create first line with one base element
element1 = pals.Marker(name="element1")
line1 = pals.BeamLine(name="line1", line=[element1])
assert line1.line == [element1]
# Extend first line with one thick element
element2 = pals.Drift(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
# Create second line with one drift element
element3 = pals.Drift(name="element3", length=3.0)
line2 = pals.BeamLine(name="line2", line=[element3])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
Comment thread
EZoni marked this conversation as resolved.
# Build lattice with two lines
lattice = pals.Lattice(name="lattice", branches=[line1, line2])
assert lattice.name == "lattice"
assert lattice.kind == "Lattice"
assert len(lattice.branches) == 2
assert lattice.branches[0].name == "line1"
assert lattice.branches[1].name == "line2"
assert lattice.branches[0].line == [element1, element2, element3]
assert lattice.branches[1].line == [element3]