@@ -61,17 +61,19 @@ def test_warm_start(self, tmp_path: Path):
6161 with tempfile .TemporaryDirectory () as temp_dir :
6262 # config for two steps model
6363 gpt2_8_steps_config_file_path = working_dir / "gpt2_train_num_steps_8.yaml"
64- gpt2_8_steps_config_dict = load_app_config_dict (
64+ gpt2_8_steps_config_dict : dict = load_app_config_dict (
6565 gpt2_8_steps_config_file_path , experiment_id = "0" , experiments_root_path = tmp_path
6666 )
6767
6868 # adopt the checkpoint path
69- checkpoint_path = temp_dir
69+ experiment_dir_0 = Path (temp_dir ) / "0"
70+ checkpoint_dir_path_0 = experiment_dir_0 / "checkpoints"
71+ experiment_dir_1 = Path (temp_dir ) / "1"
72+ checkpoint_dir_path_1 = experiment_dir_1 / "checkpoints"
7073 gpt2_8_steps_config_dict ["checkpoint_saving" ]["config" ]["checkpoint_saving_execution" ]["config" ][
7174 "checkpoint_path"
72- ] = checkpoint_path
73- gpt2_8_steps_config_dict ["settings" ]["paths" ]["checkpoint_saving_path" ] = checkpoint_path
74- loss_values_experiment_0_path = checkpoint_path + "/experiment_0_loss_scores.txt"
75+ ] = checkpoint_dir_path_0
76+ loss_values_experiment_0_path = experiment_dir_0 / "experiment_0_loss_scores.txt"
7577
7678 # config for one step model
7779 gpt2_warm_start_after_4_steps_config_file_path = working_dir / "gpt2_warm_start_from_step_4.yaml"
@@ -81,17 +83,17 @@ def test_warm_start(self, tmp_path: Path):
8183
8284 # adopt the checkpoint path
8385 gpt2_warm_start_after_4_steps_dict ["wrapped_model" ]["config" ]["checkpoint_path" ] = (
84- checkpoint_path + "/0/eid_0-model-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
86+ checkpoint_dir_path_0
87+ / "eid_0-model-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
8588 )
8689 gpt2_warm_start_after_4_steps_dict ["optimizer" ]["config" ]["checkpoint_path" ] = (
87- checkpoint_path
88- + "/0/ eid_0-optimizer-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
90+ checkpoint_dir_path_0
91+ / "eid_0-optimizer-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
8992 )
9093 gpt2_warm_start_after_4_steps_dict ["checkpoint_saving" ]["config" ]["checkpoint_saving_execution" ]["config" ][
9194 "checkpoint_path"
92- ] = checkpoint_path
93- gpt2_warm_start_after_4_steps_dict ["settings" ]["paths" ]["checkpoint_saving_path" ] = checkpoint_path
94- loss_values_experiment_1_path = checkpoint_path + "/experiment_1_loss_scores.txt"
95+ ] = checkpoint_dir_path_1
96+ loss_values_experiment_1_path = experiment_dir_1 / "experiment_1_loss_scores.txt"
9597
9698 # # adopt dataset path
9799 # gpt2_warm_start_after_4_steps_dict["train_dataset"]["config"]["raw_data_path"] = (
@@ -121,22 +123,24 @@ def test_warm_start(self, tmp_path: Path):
121123 json .dump (loss_scores_0 , f )
122124
123125 # make sure that the checkpoints have been written and checkpoint info file has been updated
124- checkpoint_info_file_path = Path (checkpoint_path ) / "0/last_checkpoint_info.json"
125- assert checkpoint_info_file_path .exists ()
126- with open (checkpoint_info_file_path , "r" ) as f :
126+ checkpoint_info_file_path_0 = Path (checkpoint_dir_path_0 ) / "last_checkpoint_info.json"
128+ assert checkpoint_info_file_path_0 .exists ()
129+ with open (checkpoint_info_file_path_0 , "r" ) as f :
127130 checkpoint_info = json .load (f )
128- assert checkpoint_info ["model_checkpoint_path" ] == (
129- checkpoint_path
130- + "/0/eid_0-model-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
131+ assert (
132+ Path (checkpoint_info ["model_checkpoint_path" ])
133+ == checkpoint_dir_path_0
134+ / "eid_0-model-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
131135 )
132- assert checkpoint_info ["optimizer_checkpoint_path" ] == (
133- checkpoint_path
134- + "/0/eid_0-optimizer-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
136+ assert (
137+ Path (checkpoint_info ["optimizer_checkpoint_path" ])
138+ == checkpoint_dir_path_0
139+ / "eid_0-optimizer-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
135140 )
136141 assert Path (checkpoint_info ["model_checkpoint_path" ]).exists ()
137142 assert Path (checkpoint_info ["optimizer_checkpoint_path" ]).exists ()
138143
139- checkpoint_paths = list (Path (checkpoint_path ).glob ("**/*.bin" ))
144+ checkpoint_paths = list (Path (checkpoint_dir_path_0 ).glob ("**/*.bin" ))
140145 model_max_seen_steps = - 1
141146 model_max_seen_tokens = - 1
142147 optimizer_max_seen_steps = - 1
0 commit comments