Skip to content

Commit c3cf4f5

Browse files
committed
refactor: fixed torchrun warmstart test
1 parent 3434e93 commit c3cf4f5

1 file changed

Lines changed: 26 additions & 21 deletions

File tree

tests/end2end_tests/test_fsdp_warmstart.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,19 @@ def test_warm_start(self, tmp_path: Path):
6161
with tempfile.TemporaryDirectory() as temp_dir:
6262
# config for two steps model
6363
gpt2_8_steps_config_file_path = working_dir / "gpt2_train_num_steps_8.yaml"
64-
gpt2_8_steps_config_dict = load_app_config_dict(
64+
gpt2_8_steps_config_dict: dict = load_app_config_dict(
6565
gpt2_8_steps_config_file_path, experiment_id="0", experiments_root_path=tmp_path
6666
)
6767

6868
# adopt the checkpoint path
69-
checkpoint_path = temp_dir
69+
experiment_dir_0 = Path(temp_dir) / "0"
70+
checkpoint_dir_path_0 = experiment_dir_0 / "checkpoints"
71+
experiment_dir_1 = Path(temp_dir) / "1"
72+
checkpoint_dir_path_1 = experiment_dir_1 / "checkpoints"
7073
gpt2_8_steps_config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][
7174
"checkpoint_path"
72-
] = checkpoint_path
73-
gpt2_8_steps_config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_path
74-
loss_values_experiment_0_path = checkpoint_path + "/experiment_0_loss_scores.txt"
75+
] = checkpoint_dir_path_0
76+
loss_values_experiment_0_path = experiment_dir_0 / "experiment_0_loss_scores.txt"
7577

7678
# config for one step model
7779
gpt2_warm_start_after_4_steps_config_file_path = working_dir / "gpt2_warm_start_from_step_4.yaml"
@@ -81,17 +83,17 @@ def test_warm_start(self, tmp_path: Path):
8183

8284
# adopt the checkpoint path
8385
gpt2_warm_start_after_4_steps_dict["wrapped_model"]["config"]["checkpoint_path"] = (
84-
checkpoint_path + "/0/eid_0-model-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
86+
checkpoint_dir_path_0
87+
/ "eid_0-model-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
8588
)
8689
gpt2_warm_start_after_4_steps_dict["optimizer"]["config"]["checkpoint_path"] = (
87-
checkpoint_path
88-
+ "/0/eid_0-optimizer-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
90+
checkpoint_dir_path_0
91+
/ "eid_0-optimizer-seen_steps_4-seen_tokens_2048-target_steps_15-target_tokens_7680.bin"
8992
)
9093
gpt2_warm_start_after_4_steps_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][
9194
"checkpoint_path"
92-
] = checkpoint_path
93-
gpt2_warm_start_after_4_steps_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_path
94-
loss_values_experiment_1_path = checkpoint_path + "/experiment_1_loss_scores.txt"
95+
] = checkpoint_dir_path_1
96+
loss_values_experiment_1_path = experiment_dir_1 / "experiment_1_loss_scores.txt"
9597

9698
# # adopt dataset path
9799
# gpt2_warm_start_after_4_steps_dict["train_dataset"]["config"]["raw_data_path"] = (
@@ -121,22 +123,25 @@ def test_warm_start(self, tmp_path: Path):
121123
json.dump(loss_scores_0, f)
122124

123125
# make sure that the checkpoints have been written and checkpoint info file has been updated
124-
checkpoint_info_file_path = Path(checkpoint_path) / "0/last_checkpoint_info.json"
125-
assert checkpoint_info_file_path.exists()
126-
with open(checkpoint_info_file_path, "r") as f:
126+
checkpoint_info_file_path_0 = Path(checkpoint_dir_path_0) / "last_checkpoint_info.json"
127+
print(list(Path(checkpoint_dir_path_0).glob("**/last_checkpoint_info.json")))
128+
assert checkpoint_info_file_path_0.exists()
129+
with open(checkpoint_info_file_path_0, "r") as f:
127130
checkpoint_info = json.load(f)
128-
assert checkpoint_info["model_checkpoint_path"] == (
129-
checkpoint_path
130-
+ "/0/eid_0-model-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
131+
assert (
132+
Path(checkpoint_info["model_checkpoint_path"])
133+
== checkpoint_dir_path_0
134+
/ "eid_0-model-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
131135
)
132-
assert checkpoint_info["optimizer_checkpoint_path"] == (
133-
checkpoint_path
134-
+ "/0/eid_0-optimizer-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
136+
assert (
137+
Path(checkpoint_info["optimizer_checkpoint_path"])
138+
== checkpoint_dir_path_0
139+
/ "eid_0-optimizer-seen_steps_12-seen_tokens_6144-target_steps_15-target_tokens_7680.bin"
135140
)
136141
assert Path(checkpoint_info["model_checkpoint_path"]).exists()
137142
assert Path(checkpoint_info["optimizer_checkpoint_path"]).exists()
138143

139-
checkpoint_paths = list(Path(checkpoint_path).glob("**/*.bin"))
144+
checkpoint_paths = list(Path(checkpoint_dir_path_0).glob("**/*.bin"))
140145
model_max_seen_steps = -1
141146
model_max_seen_tokens = -1
142147
optimizer_max_seen_steps = -1

0 commit comments

Comments (0)