Skip to content

Commit 7ff3bb6

Browse files
EZoniax3l
andauthored
Add Lattice class (#59)
Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja>
1 parent 7c45243 commit 7ff3bb6

6 files changed

Lines changed: 130 additions & 45 deletions

File tree

examples/fodo.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pals import Drift
33
from pals import Quadrupole
44
from pals import BeamLine
5+
from pals import Lattice
56

67

78
def main():
@@ -42,26 +43,31 @@ def main():
4243
drift3,
4344
],
4445
)
46+
# Create lattice with the line as a branch
47+
lattice = Lattice(
48+
name="fodo_lattice",
49+
branches=[line],
50+
)
4551

4652
# Serialize to YAML
4753
yaml_file = "examples_fodo.pals.yaml"
48-
line.to_file(yaml_file)
54+
lattice.to_file(yaml_file)
4955

5056
# Read YAML data from file
51-
loaded_line = BeamLine.from_file(yaml_file)
57+
loaded_lattice = Lattice.from_file(yaml_file)
5258

5359
# Validate loaded data
54-
assert line == loaded_line
60+
assert lattice == loaded_lattice
5561

5662
# Serialize to JSON
5763
json_file = "examples_fodo.pals.json"
58-
line.to_file(json_file)
64+
lattice.to_file(json_file)
5965

6066
# Read JSON data from file
61-
loaded_line = BeamLine.from_file(json_file)
67+
loaded_lattice = Lattice.from_file(json_file)
6268

6369
# Validate loaded data
64-
assert line == loaded_line
70+
assert lattice == loaded_lattice
6571

6672

6773
if __name__ == "__main__":

src/pals/kinds/Lattice.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from pydantic import model_validator, Field
2+
from typing import Annotated, List, Literal, Union
3+
4+
from .BeamLine import BeamLine
5+
from .mixin import BaseElement
6+
from ..functions import load_file_to_dict, store_dict_to_file
7+
8+
9+
class Lattice(BaseElement):
10+
"""A lattice combines beamlines"""
11+
12+
kind: Literal["Lattice"] = "Lattice"
13+
14+
branches: List[Annotated[Union[BeamLine], Field(discriminator="kind")]]
15+
16+
@model_validator(mode="before")
17+
@classmethod
18+
def unpack_json_structure(cls, data):
19+
"""Deserialize the JSON/YAML/...-like dict for Lattice elements"""
20+
from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure
21+
22+
return unpack_element_list_structure(data, "branches", "branches")
23+
24+
def model_dump(self, *args, **kwargs):
25+
"""Custom model dump for Lattice to handle element list formatting"""
26+
from pals.kinds.mixin.all_element_mixin import dump_element_list
27+
28+
return dump_element_list(self, "branches", *args, **kwargs)
29+
30+
@staticmethod
31+
def from_file(filename: str) -> "Lattice":
32+
"""Load a Lattice from a text file"""
33+
pals_dict = load_file_to_dict(filename)
34+
return Lattice(**pals_dict)
35+
36+
def to_file(self, filename: str):
37+
"""Save a Lattice to a text file"""
38+
pals_dict = self.model_dump()
39+
store_dict_to_file(filename, pals_dict)

src/pals/kinds/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .ACKicker import ACKicker # noqa: F401
66
from .BeamBeam import BeamBeam # noqa: F401
77
from .BeamLine import BeamLine # noqa: F401
8+
from .Lattice import Lattice # noqa: F401
89
from .BeginningEle import BeginningEle # noqa: F401
910
from .Converter import Converter # noqa: F401
1011
from .CrabCavity import CrabCavity # noqa: F401
@@ -39,3 +40,4 @@
3940
# Rebuild pydantic models that depend on other classes
4041
UnionEle.model_rebuild()
4142
BeamLine.model_rebuild()
43+
Lattice.model_rebuild()

src/pals/kinds/all_elements.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
def get_all_element_types(extra_types: tuple = None):
4444
"""Return a tuple of all element types that can be used in BeamLine or UnionEle."""
4545
element_types = (
46+
"Lattice", # Forward reference to handle circular import
4647
"BeamLine", # Forward reference to handle circular import
4748
"UnionEle", # Forward reference to handle circular import
4849
ACKicker,

tests/test_elements.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,29 @@ def test_UnionEle():
511511
assert len(element_with_children.elements) == 2
512512
assert element_with_children.elements[0].name == "m1"
513513
assert element_with_children.elements[1].name == "d1"
514+
515+
516+
def test_Lattice():
517+
# Create first line with one base element
518+
element1 = pals.Marker(name="element1")
519+
line1 = pals.BeamLine(name="line1", line=[element1])
520+
assert line1.line == [element1]
521+
# Extend first line with one thick element
522+
element2 = pals.Drift(name="element2", length=2.0)
523+
line1.line.extend([element2])
524+
assert line1.line == [element1, element2]
525+
# Create second line with one drift element
526+
element3 = pals.Drift(name="element3", length=3.0)
527+
line2 = pals.BeamLine(name="line2", line=[element3])
528+
# Extend first line with second line
529+
line1.line.extend(line2.line)
530+
assert line1.line == [element1, element2, element3]
531+
# Build lattice with two lines
532+
lattice = pals.Lattice(name="lattice", branches=[line1, line2])
533+
assert lattice.name == "lattice"
534+
assert lattice.kind == "Lattice"
535+
assert len(lattice.branches) == 2
536+
assert lattice.branches[0].name == "line1"
537+
assert lattice.branches[1].name == "line2"
538+
assert lattice.branches[0].line == [element1, element2, element3]
539+
assert lattice.branches[1].line == [element3]

tests/test_serialization.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -171,40 +171,45 @@ def test_comprehensive_lattice():
171171
wiggler = pals.Wiggler(name="wiggler1", length=2.0)
172172

173173
# Create comprehensive lattice
174-
lattice = pals.BeamLine(
174+
lattice = pals.Lattice(
175175
name="comprehensive_lattice",
176-
line=[
177-
beginning, # Start with beginning element
178-
fiducial, # Global coordinate reference
179-
marker, # Mark position
180-
drift, # Field-free region
181-
quadrupole, # Focusing element
182-
sextupole, # Chromatic correction
183-
octupole, # Higher order correction
184-
multipole, # General multipole
185-
rbend, # Rectangular bend
186-
sbend, # Sector bend
187-
solenoid, # Longitudinal focusing
188-
rfcavity, # RF acceleration
189-
crabcavity, # RF crab cavity
190-
kicker, # Transverse kick
191-
ackicker, # AC kicker
192-
patch, # Coordinate transformation
193-
floorshift, # Global coordinate shift
194-
instrument, # Measurement device
195-
mask, # Collimation
196-
match, # Matching element
197-
egun, # Electron source
198-
converter, # Species conversion
199-
foil, # Electron stripping
200-
beambeam, # Colliding beams
201-
feedback, # Feedback system
202-
girder, # Support structure
203-
fork, # Branch connection
204-
taylor, # Taylor map
205-
unionele, # Overlapping elements
206-
wiggler, # Undulator
207-
nullele, # Placeholder
176+
branches=[
177+
pals.BeamLine(
178+
name="comprehensive_beamline",
179+
line=[
180+
beginning, # Start with beginning element
181+
fiducial, # Global coordinate reference
182+
marker, # Mark position
183+
drift, # Field-free region
184+
quadrupole, # Focusing element
185+
sextupole, # Chromatic correction
186+
octupole, # Higher order correction
187+
multipole, # General multipole
188+
rbend, # Rectangular bend
189+
sbend, # Sector bend
190+
solenoid, # Longitudinal focusing
191+
rfcavity, # RF acceleration
192+
crabcavity, # RF crab cavity
193+
kicker, # Transverse kick
194+
ackicker, # AC kicker
195+
patch, # Coordinate transformation
196+
floorshift, # Global coordinate shift
197+
instrument, # Measurement device
198+
mask, # Collimation
199+
match, # Matching element
200+
egun, # Electron source
201+
converter, # Species conversion
202+
foil, # Electron stripping
203+
beambeam, # Colliding beams
204+
feedback, # Feedback system
205+
girder, # Support structure
206+
fork, # Branch connection
207+
taylor, # Taylor map
208+
unionele, # Overlapping elements
209+
wiggler, # Undulator
210+
nullele, # Placeholder
211+
],
212+
)
208213
],
209214
)
210215

@@ -217,10 +222,13 @@ def test_comprehensive_lattice():
217222
print(f"\nComprehensive lattice YAML:\n{file.read()}")
218223

219224
# Deserialize back to Python object using Pydantic model logic
220-
loaded_lattice = pals.BeamLine.from_file(yaml_file)
225+
loaded_lattice = pals.Lattice.from_file(yaml_file)
221226

222227
# Verify the loaded lattice has the correct structure and parameter groups
223-
assert len(loaded_lattice.line) == 31 # Should have 31 elements
228+
assert len(loaded_lattice.branches) == 1 # Should have 1 branch
229+
assert (
230+
len(loaded_lattice.branches[0].line) == 31
231+
) # Should have 31 elements in the branch
224232

225233
# Verify specific elements with parameter groups are correctly loaded
226234
sextupole_loaded = None
@@ -229,7 +237,7 @@ def test_comprehensive_lattice():
229237
rfcavity_loaded = None
230238
unionele_loaded = None
231239

232-
for elem in loaded_lattice.line:
240+
for elem in loaded_lattice.branches[0].line:
233241
if elem.name == "sextupole1":
234242
sextupole_loaded = elem
235243
elif elem.name == "octupole1":
@@ -272,10 +280,13 @@ def test_comprehensive_lattice():
272280
print(f"\nComprehensive lattice JSON:\n{file.read()}")
273281

274282
# Deserialize back to Python object using Pydantic model logic
275-
loaded_lattice_json = pals.BeamLine.from_file(json_file)
283+
loaded_lattice_json = pals.Lattice.from_file(json_file)
276284

277285
# Verify the loaded lattice has the correct structure and parameter groups
278-
assert len(loaded_lattice_json.line) == 31 # Should have 31 elements
286+
assert len(loaded_lattice_json.branches) == 1 # Should have 1 branch
287+
assert (
288+
len(loaded_lattice_json.branches[0].line) == 31
289+
) # Should have 31 elements in the branch
279290

280291
# Verify specific elements with parameter groups are correctly loaded
281292
sextupole_loaded_json = None
@@ -284,7 +295,7 @@ def test_comprehensive_lattice():
284295
rfcavity_loaded_json = None
285296
unionele_loaded_json = None
286297

287-
for elem in loaded_lattice_json.line:
298+
for elem in loaded_lattice_json.branches[0].line:
288299
if elem.name == "sextupole1":
289300
sextupole_loaded_json = elem
290301
elif elem.name == "octupole1":

0 commit comments

Comments
 (0)