1+ import copy
12import csv
23import logging
34import os
@@ -60,6 +61,7 @@ def __init__(
6061 self .metrics = {}
6162 self .epoch_time = None
6263 self .best_val_metric = 1e10
64+ self .best_model_state = None
6365
6466 self .evaluator = Evaluator ()
6567
@@ -218,6 +220,44 @@ def validate(self):
218220 def predict (self ):
219221 """Implemented by derived classes."""
220222
223+ def update_best_model (self , val_metrics ):
224+ """Updates the best val metric and model, saves the best model, and saves the best model predictions"""
225+ self .best_val_metric = val_metrics [type (self .loss_fn ).__name__ ]["metric" ]
226+ self .best_model_state = copy .deepcopy (self .model .state_dict ())
227+
228+ self .save_model ("best_checkpoint.pt" , val_metrics , False )
229+
230+ logging .debug (
231+ f"Saving prediction results for epoch { self .epoch } to: /results/{ self .timestamp_id } /"
232+ )
233+ self .predict (self .train_loader , "train" )
234+ self .predict (self .val_loader , "val" )
235+ self .predict (self .test_loader , "test" )
236+
237+ def save_model (self , checkpoint_file , val_metrics = None , training_state = True ):
238+ """Saves the model state dict"""
239+
240+ if training_state :
241+ state = {
242+ "epoch" : self .epoch ,
243+ "step" : self .step ,
244+ "state_dict" : self .model .state_dict (),
245+ "optimizer" : self .optimizer .state_dict (),
246+ "scheduler" : self .scheduler .scheduler .state_dict (),
247+ "best_val_metric" : self .best_val_metric ,
248+ }
249+ else :
250+ state = {"state_dict" : self .model .state_dict (), "val_metrics" : val_metrics }
251+
252+ checkpoint_dir = os .path .join (
253+ self .run_dir , "results" , self .timestamp_id , "checkpoint"
254+ )
255+ os .makedirs (checkpoint_dir , exist_ok = True )
256+ filename = os .path .join (checkpoint_dir , checkpoint_file )
257+
258+ torch .save (state , filename )
259+ return filename
260+
221261 def save_results (self , output , filename , node_level_predictions = False ):
222262 results_path = os .path .join (self .run_dir , "results" , self .timestamp_id )
223263 os .makedirs (results_path , exist_ok = True )
@@ -237,3 +277,9 @@ def save_results(self, output, filename, node_level_predictions=False):
237277 csvwriter .writerow (headers )
238278 elif i > 0 :
239279 csvwriter .writerow (output [i - 1 , :])
280+ return filename
281+
282+ def load_checkpoint (self ):
283+ """Loads the model from a checkpoint.pt file"""
284+ # TODO: implement this method
285+ pass
0 commit comments