Skip to content

Commit df48759

Browse files
precicely and waltsims authored
#592: Add checkpoint flags and some test logic. (#599)
* #592: added checkpoint flags and some test logic. * #592: added minimal example of how to use checkpointing. * #592: added windows guard in tests. * #592: improved test coverage. * #592: moved checkpoint validation to instantiation. Added setters/getters. * #592: added additional "valid" test for checkpointing. * #592: tweaked tests. * #592: addressed gap in testing coverage. --------- Co-authored-by: Walter Simson <walter.a.simson@gmail.com>
1 parent 4435c27 commit df48759

4 files changed

Lines changed: 288 additions & 0 deletions

File tree

examples/checkpointing/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Checkpointing
2+
3+
This example demonstrates how to checkpoint simulations.
4+
5+
It uses the same simulation parameters as found in "3D FFT Reconstruction For A Planar Sensor".
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from copy import copy
2+
from pathlib import Path
3+
from tempfile import TemporaryDirectory
4+
5+
import numpy as np
6+
7+
from kwave.data import Vector
8+
from kwave.kgrid import kWaveGrid
9+
from kwave.kmedium import kWaveMedium
10+
from kwave.ksensor import kSensor
11+
from kwave.ksource import kSource
12+
from kwave.kspaceFirstOrder3D import kspaceFirstOrder3D
13+
from kwave.options.simulation_execution_options import SimulationExecutionOptions
14+
from kwave.options.simulation_options import SimulationOptions
15+
from kwave.utils.filters import smooth
16+
from kwave.utils.mapgen import make_ball
17+
18+
# This script demonstrates how to use the checkpointing feature of k-Wave.
19+
# It runs the same simulation twice, with the second run starting from a
20+
# checkpoint file created during the first run.
21+
# This implementation checkpoints based on the number of completed timesteps.
22+
# Alternatively, you can checkpoint based on the simulation time. Below is a
23+
# snippet demonstrating how:
24+
#
25+
# checkpoint_interval = 1 # 1 second
26+
# execution_options = SimulationExecutionOptions(
27+
# checkpoint_file=checkpoint_filename,
28+
# checkpoint_interval=checkpoint_interval
29+
# )
30+
# Note: checkpoint timesteps and checkpoint interval must be integers.
31+
32+
33+
def make_simulation_parameters(directory: Path, checkpoint_timesteps: int):
    """
    Assemble all inputs for the checkpointed 3D simulation.

    See the 3D FFT Reconstruction For A Planar Sensor example for context.

    Args:
        directory: Folder in which the input, output and checkpoint HDF5
            file paths are placed.
        checkpoint_timesteps: Value forwarded to SimulationExecutionOptions
            as ``checkpoint_timesteps``.

    Returns:
        The tuple ``(kgrid, medium, source, sensor, simulation_options,
        execution_options)``.
    """
    scale = 1

    # computational grid
    pml_thickness = 10  # size of the PML in grid points
    grid_points = Vector([32, 64, 64]) * scale - 2 * pml_thickness  # number of grid points
    grid_spacing = Vector([0.2e-3, 0.2e-3, 0.2e-3]) / scale  # grid point spacing [m]
    kgrid = kWaveGrid(grid_points, grid_spacing)

    # properties of the propagation medium
    medium = kWaveMedium(sound_speed=1500)  # [m/s]

    # initial pressure distribution: a smoothed ball centred in the grid
    ball_magnitude = 10  # [Pa]
    ball_radius = 3 * scale  # [grid points]
    initial_pressure = smooth(
        ball_magnitude * make_ball(grid_points, grid_points / 2, ball_radius),
        restore_max=True,
    )

    source = kSource()
    source.p0 = initial_pressure

    # binary planar sensor: only the first index along the first axis is set
    plane_mask = np.zeros(grid_points)
    plane_mask[0] = 1
    sensor = kSensor()
    sensor.mask = plane_mask

    # input arguments; p0 is already smoothed above, hence smooth_p0=False
    simulation_options = SimulationOptions(
        save_to_disk=True,
        pml_size=pml_thickness,
        pml_inside=False,
        smooth_p0=False,
        data_cast="single",
        input_filename=directory / "kwave_input.h5",
        output_filename=directory / "kwave_output.h5",
    )

    execution_options = SimulationExecutionOptions(
        checkpoint_file=directory / "kwave_checkpoint.h5",
        checkpoint_timesteps=checkpoint_timesteps,
    )

    return kgrid, medium, source, sensor, simulation_options, execution_options
80+
81+
82+
def main():
    """
    Run the same simulation twice inside one temporary directory.

    After each run the names of the ``*.h5`` files found in the directory are
    printed so the presence of the checkpoint file can be inspected.
    """
    # This simulation has 212 timesteps, therefore, the checkpoint is taken at
    # the halfway point.
    checkpoint_timesteps = 106

    def run_and_report(workdir, headline):
        # Fresh parameters are built for every run, exactly as for a first run.
        kgrid, medium, source, sensor, simulation_options, execution_options = make_simulation_parameters(
            directory=workdir, checkpoint_timesteps=checkpoint_timesteps
        )
        kspaceFirstOrder3D(kgrid, source, sensor, medium, simulation_options, execution_options)
        print(headline)
        print("\t-", "\n\t- ".join([f.name for f in workdir.glob("*.h5")]))

    with TemporaryDirectory() as tmpdir:
        workdir = Path(tmpdir)
        # 1st run
        run_and_report(workdir, "Temporary directory contains 3 files (including checkpoint):")
        # 2nd run
        run_and_report(workdir, "Temporary directory contains 2 files (checkpoint has been deleted):")


if __name__ == "__main__":
    main()

kwave/options/simulation_execution_options.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import os
22
import warnings
3+
from logging import getLogger
34
from pathlib import Path
45
from typing import Optional, Union
56

67
from kwave import BINARY_DIR, PLATFORM
78
from kwave.ksensor import kSensor
89

10+
logger = getLogger(__name__)
11+
912

1013
class SimulationExecutionOptions:
1114
"""
@@ -27,6 +30,9 @@ def __init__(
2730
verbose_level: int = 0,
2831
auto_chunking: Optional[bool] = True,
2932
show_sim_log: bool = True,
33+
checkpoint_interval: Optional[int] = None, # [seconds]
34+
checkpoint_timesteps: Optional[int] = None, # [timestep integer]
35+
checkpoint_file: Optional[Path | str] = None, # [path to hdf5 file]
3036
):
3137
self.is_gpu_simulation = is_gpu_simulation
3238
self._binary_path = binary_path
@@ -41,6 +47,13 @@ def __init__(
4147
self.verbose_level = verbose_level
4248
self.auto_chunking = auto_chunking
4349
self.show_sim_log = show_sim_log
50+
self.checkpoint_interval = checkpoint_interval
51+
self.checkpoint_timesteps = checkpoint_timesteps
52+
self.checkpoint_file = checkpoint_file
53+
54+
if self.checkpoint_file is not None:
55+
if self.checkpoint_interval is None and self.checkpoint_timesteps is None:
56+
raise ValueError("One of checkpoint_interval or checkpoint_timesteps must be set when checkpoint_file is set.")
4457

4558
@property
4659
def num_threads(self) -> Union[int, str]:
@@ -164,6 +177,47 @@ def device_num(self, value: Optional[int]):
164177
raise ValueError("Device number must be non-negative")
165178
self._device_num = value
166179

180+
@property
def checkpoint_interval(self) -> Optional[int]:
    """Return the checkpoint interval in seconds, or None when it is not set."""
    return self._checkpoint_interval

@checkpoint_interval.setter
def checkpoint_interval(self, value: Optional[int]):
    """
    Set the checkpoint interval.

    Args:
        value: A strictly positive integer number of seconds, or None to
            leave the interval unset.

    Raises:
        ValueError: If ``value`` is not None and is not a positive ``int``.
            Zero and ``bool`` values are rejected as well.
    """
    if value is not None:
        # bool is a subclass of int; without the explicit check True/False
        # would be accepted and later rendered as "True"/"False" by str().
        is_plain_int = isinstance(value, int) and not isinstance(value, bool)
        # The message promises a *positive* integer, so zero is invalid too.
        if not is_plain_int or value <= 0:
            raise ValueError("Checkpoint interval must be a positive integer")
    self._checkpoint_interval = value
190+
191+
@property
def checkpoint_timesteps(self) -> Optional[int]:
    """Return the number of timesteps between checkpoints, or None when it is not set."""
    return self._checkpoint_timesteps

@checkpoint_timesteps.setter
def checkpoint_timesteps(self, value: Optional[int]):
    """
    Set the number of timesteps between checkpoints.

    Args:
        value: A strictly positive integer timestep count, or None to leave
            it unset.

    Raises:
        ValueError: If ``value`` is not None and is not a positive ``int``.
            Zero and ``bool`` values are rejected as well.
    """
    if value is not None:
        # bool is a subclass of int; without the explicit check True/False
        # would be accepted and later rendered as "True"/"False" by str().
        is_plain_int = isinstance(value, int) and not isinstance(value, bool)
        # The message promises a *positive* integer, so zero is invalid too.
        if not is_plain_int or value <= 0:
            raise ValueError("Checkpoint timesteps must be a positive integer")
    self._checkpoint_timesteps = value
201+
202+
@property
def checkpoint_file(self) -> Optional[Path]:
    """Return the checkpoint file path, or None when it is not set."""
    return self._checkpoint_file

@checkpoint_file.setter
def checkpoint_file(self, value: Optional[Union[Path, str]]):
    """
    Set the checkpoint file path.

    Strings are converted to ``Path``; ``Path`` instances are stored as given.

    Args:
        value: Path to an ``.h5`` file inside an existing directory, or None
            to leave the checkpoint file unset.

    Raises:
        ValueError: If ``value`` is neither ``str`` nor ``Path``, or if its
            suffix is not ``.h5``.
        FileNotFoundError: If the parent directory of ``value`` does not exist.
    """
    if value is None:
        self._checkpoint_file = None
        return

    if not isinstance(value, (str, Path)):
        raise ValueError("Checkpoint file must be a string or Path object.")
    if isinstance(value, str):
        value = Path(value)

    # The directory check runs before the suffix check, so a path that fails
    # both raises FileNotFoundError.
    if not value.parent.is_dir():
        raise FileNotFoundError(f"Checkpoint folder {value.parent} does not exist.")
    if value.suffix != ".h5":
        raise ValueError(f"Checkpoint file {value} must have .h5 extension.")

    self._checkpoint_file = value
220+
167221
def as_list(self, sensor: kSensor) -> list[str]:
168222
options_list = []
169223

@@ -179,6 +233,17 @@ def as_list(self, sensor: kSensor) -> list[str]:
179233
options_list.append("--verbose")
180234
options_list.append(str(self.verbose_level))
181235

236+
if (self.checkpoint_interval is not None or self.checkpoint_timesteps is not None) and self.checkpoint_file is not None:
237+
if self.checkpoint_timesteps is not None:
238+
options_list.append("--checkpoint_timesteps")
239+
options_list.append(str(self.checkpoint_timesteps))
240+
if self.checkpoint_interval is not None:
241+
options_list.append("--checkpoint_interval")
242+
options_list.append(str(self.checkpoint_interval))
243+
244+
options_list.append("--checkpoint_file")
245+
options_list.append(str(self.checkpoint_file))
246+
182247
record_options_map = {
183248
"p": "p_raw",
184249
"p_max": "p_max",

tests/test_simulation_execution_options.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import unittest
3+
from pathlib import Path
4+
from tempfile import TemporaryDirectory
35
from unittest.mock import Mock, patch
46

57
from kwave import PLATFORM
@@ -31,6 +33,8 @@ def test_default_initialization(self):
3133
self.assertEqual(options.verbose_level, 0)
3234
self.assertTrue(options.auto_chunking)
3335
self.assertTrue(options.show_sim_log)
36+
self.assertIsNone(options.checkpoint_interval)
37+
self.assertIsNone(options.checkpoint_file)
3438

3539
def test_num_threads_setter_valid(self):
3640
"""Test setting a valid number of threads."""
@@ -214,6 +218,113 @@ def test_list_compared_to_string(self):
214218
options_string = options.get_options_string(self.mock_sensor)
215219
self.assertEqual(" ".join(options_list), options_string)
216220

221+
def test_as_list_with_valid_checkpoint_interval_options(self):
    """Check the argument list produced when a checkpoint interval and file are set."""
    opts = self.default_options
    opts.num_threads = 1
    opts.checkpoint_interval = 10
    opts.checkpoint_file = "checkpoint.h5"

    actual = opts.as_list(self.mock_sensor)

    # The thread-count flag is only expected on non-Windows platforms.
    thread_args = [] if PLATFORM == "windows" else ["-t", "1"]
    expected = thread_args + [
        "--checkpoint_interval",
        "10",
        "--checkpoint_file",
        "checkpoint.h5",
        "--p_raw",
        "--u_max",
        "-s",
        "10",
    ]
    self.assertListEqual(expected, actual)
243+
244+
def test_as_list_with_valid_checkpoint_timesteps_options(self):
    """Test the checkpoint timesteps options."""
    options = self.default_options
    options.num_threads = 1
    options.checkpoint_timesteps = 10
    options.checkpoint_file = "checkpoint.h5"

    options_list = options.as_list(self.mock_sensor)

    expected_elements = [
        "--checkpoint_timesteps",
        "10",
        "--checkpoint_file",
        "checkpoint.h5",
        "--p_raw",
        "--u_max",
        "-s",
        "10",
    ]
    # The thread-count flag is only expected on non-Windows platforms.
    if PLATFORM != "windows":
        expected_elements = ["-t", "1"] + expected_elements
    self.assertListEqual(expected_elements, options_list)
266+
267+
def test_initialization_with_invalid_checkpoint_options(self):
    """Verify that invalid checkpoint arguments are rejected at construction time."""

    # A valid checkpoint_file with neither checkpoint_timesteps nor
    # checkpoint_interval set.
    with TemporaryDirectory() as temp_dir:
        with self.assertRaises(ValueError):
            SimulationExecutionOptions(checkpoint_file=Path(temp_dir) / "checkpoint.h5")

    value_error_cases = [
        # invalid checkpoint_file type
        {"checkpoint_file": 12345, "checkpoint_timesteps": 10},
        # invalid checkpoint_timesteps
        {"checkpoint_timesteps": -1, "checkpoint_file": "checkpoint.h5"},
        {"checkpoint_timesteps": "not_an_integer", "checkpoint_file": "checkpoint.h5"},
        # invalid checkpoint_interval
        {"checkpoint_interval": -1, "checkpoint_file": "checkpoint.h5"},
        {"checkpoint_interval": "not_an_integer", "checkpoint_file": "checkpoint.h5"},
        # wrong file extension
        {"checkpoint_interval": 10, "checkpoint_file": "checkpoint.txt"},
    ]
    for kwargs in value_error_cases:
        with self.assertRaises(ValueError):
            SimulationExecutionOptions(**kwargs)

    # Non-existent directory
    with self.assertRaises(FileNotFoundError):
        SimulationExecutionOptions(checkpoint_interval=10, checkpoint_file="nonexistent_dir/checkpoint.h5")
306+
307+
def test_initialization_with_valid_checkpoint_options(self):
    """Verify that valid checkpoint arguments are stored and readable through the properties."""
    with TemporaryDirectory() as temp_dir:
        target = Path(temp_dir) / "checkpoint.h5"

        # checkpoint_timesteps only
        by_steps = SimulationExecutionOptions(checkpoint_timesteps=10, checkpoint_file=target)
        self.assertEqual(by_steps.checkpoint_timesteps, 10)
        self.assertEqual(by_steps.checkpoint_file, target)

        # checkpoint_interval only
        by_interval = SimulationExecutionOptions(checkpoint_interval=20, checkpoint_file=target)
        self.assertEqual(by_interval.checkpoint_interval, 20)
        self.assertEqual(by_interval.checkpoint_file, target)

        # both checkpoint options together - should be valid
        by_both = SimulationExecutionOptions(
            checkpoint_timesteps=10,
            checkpoint_interval=20,
            checkpoint_file=target,
        )
        self.assertEqual(by_both.checkpoint_timesteps, 10)
        self.assertEqual(by_both.checkpoint_interval, 20)
        self.assertEqual(by_both.checkpoint_file, target)
327+
217328

218329
if __name__ == "__main__":
219330
unittest.main()

0 commit comments

Comments
 (0)