@@ -335,8 +335,29 @@ def test_checkpoint_timesteps_validation(self):
335335 def test_checkpoint_file_validation (self ):
336336 """Test validation of checkpoint file path."""
337337 options = self .default_options
338- with self .assertRaises (ValueError ):
338+
339+ # Test with non-existent directory
340+ with self .assertRaises (FileNotFoundError ) as cm :
339341 options .checkpoint_file = "invalid/path/checkpoint.h5"
342+ self .assertEqual (str (cm .exception ), "Checkpoint folder invalid/path does not exist." )
343+
344+ # Test with temporary directory
345+ with TemporaryDirectory () as temp_dir :
346+ # Test invalid file extension
347+ invalid_file = Path (temp_dir ) / "checkpoint.txt"
348+ with self .assertRaises (ValueError ) as cm :
349+ options .checkpoint_file = invalid_file
350+ self .assertEqual (str (cm .exception ), f"Checkpoint file { invalid_file } must have .h5 extension." )
351+
352+ # Test valid file path
353+ valid_file = Path (temp_dir ) / "checkpoint.h5"
354+ options .checkpoint_file = valid_file
355+ self .assertEqual (options .checkpoint_file , valid_file )
356+
357+ # Test invalid type
358+ with self .assertRaises (ValueError ) as cm :
359+ options .checkpoint_file = 123
360+ self .assertEqual (str (cm .exception ), "Checkpoint file must be a string or Path object." )
340361
341362 def test_checkpoint_file_required_when_parameters_set (self ):
342363 """Test that checkpoint file is required when checkpoint parameters are set."""
0 commit comments