1+ import copy
2+ import csv
13import logging
4+ import os
25from abc import ABC , abstractmethod
6+ from datetime import datetime
37
48import torch
59import torch .optim as optim
1721from matdeeplearn .common .registry import registry
1822from matdeeplearn .models .base_model import BaseModel
1923from matdeeplearn .modules .evaluator import Evaluator
20- from matdeeplearn .modules .loss import *
2124from matdeeplearn .modules .scheduler import LRScheduler
2225
2326
@@ -35,6 +38,7 @@ def __init__(
3538 test_loader : DataLoader ,
3639 loss : nn .Module ,
3740 max_epochs : int ,
41+ identifier : str = None ,
3842 verbosity : int = None ,
3943 ):
4044 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -56,9 +60,20 @@ def __init__(
5660 self .step = 0
5761 self .metrics = {}
5862 self .epoch_time = None
63+ self .best_val_metric = 1e10
64+ self .best_model_state = None
5965
6066 self .evaluator = Evaluator ()
6167
68+ self .run_dir = os .getcwd ()
69+
70+ timestamp = torch .tensor (datetime .now ().timestamp ()).to (self .device )
71+ self .timestamp_id = datetime .fromtimestamp (timestamp .int ()).strftime (
72+ "%Y-%m-%d-%H-%M-%S"
73+ )
74+ if identifier :
75+ self .timestamp_id = f"{ self .timestamp_id } -{ identifier } "
76+
6277 if self .train_verbosity :
6378 logging .info (
6479 f"GPU is available: { torch .cuda .is_available ()} , Quantity: { torch .cuda .device_count ()} "
@@ -94,6 +109,7 @@ def from_config(cls, config):
94109 loss = cls ._load_loss (config ["optim" ]["loss" ])
95110
96111 max_epochs = config ["optim" ]["max_epochs" ]
112+ identifier = config ["task" ].get ("identifier" , None )
97113 verbosity = config ["task" ].get ("verbosity" , None )
98114
99115 return cls (
@@ -107,6 +123,7 @@ def from_config(cls, config):
107123 test_loader = test_loader ,
108124 loss = loss ,
109125 max_epochs = max_epochs ,
126+ identifier = identifier ,
110127 verbosity = verbosity ,
111128 )
112129
@@ -180,15 +197,12 @@ def _load_scheduler(scheduler_config, optimizer):
180197 @staticmethod
181198 def _load_loss (loss_config ):
182199 """Loads the loss from either the TorchLossWrapper or custom loss functions in matdeeplearn"""
183- try :
184- loss_type = loss_config ["loss_type" ]
185- # if there are other params for loss type, include in call
186- if loss_config .get ("loss_args" ):
187- return eval (loss_type )(** loss_config ["loss_args" ])
188- else :
189- return eval (loss_type )()
190- except (AttributeError , NameError ):
191- raise NotImplementedError (f"Unknown loss class name: { loss_type } " )
200+ loss_cls = registry .get_loss_class (loss_config ["loss_type" ])
201+ # if there are other params for loss type, include in call
202+ if loss_config .get ("loss_args" ):
203+ return loss_cls (** loss_config ["loss_args" ])
204+ else :
205+ return loss_cls ()
192206
193207 @abstractmethod
194208 def _load_task (self ):
@@ -205,3 +219,67 @@ def validate(self):
    @abstractmethod
    def predict(self):
        """Implemented by derived classes."""
        # NOTE(review): update_best_model invokes self.predict(loader, split)
        # with two arguments, so concrete implementations evidently accept
        # (loader, split) even though this abstract stub declares no
        # parameters — consider aligning the abstract signature.
222+
223+ def update_best_model (self , val_metrics ):
224+ """Updates the best val metric and model, saves the best model, and saves the best model predictions"""
225+ self .best_val_metric = val_metrics [type (self .loss_fn ).__name__ ]["metric" ]
226+ self .best_model_state = copy .deepcopy (self .model .state_dict ())
227+
228+ self .save_model ("best_checkpoint.pt" , val_metrics , False )
229+
230+ logging .debug (
231+ f"Saving prediction results for epoch { self .epoch } to: /results/{ self .timestamp_id } /"
232+ )
233+ self .predict (self .train_loader , "train" )
234+ self .predict (self .val_loader , "val" )
235+ self .predict (self .test_loader , "test" )
236+
237+ def save_model (self , checkpoint_file , val_metrics = None , training_state = True ):
238+ """Saves the model state dict"""
239+
240+ if training_state :
241+ state = {
242+ "epoch" : self .epoch ,
243+ "step" : self .step ,
244+ "state_dict" : self .model .state_dict (),
245+ "optimizer" : self .optimizer .state_dict (),
246+ "scheduler" : self .scheduler .scheduler .state_dict (),
247+ "best_val_metric" : self .best_val_metric ,
248+ }
249+ else :
250+ state = {"state_dict" : self .model .state_dict (), "val_metrics" : val_metrics }
251+
252+ checkpoint_dir = os .path .join (
253+ self .run_dir , "results" , self .timestamp_id , "checkpoint"
254+ )
255+ os .makedirs (checkpoint_dir , exist_ok = True )
256+ filename = os .path .join (checkpoint_dir , checkpoint_file )
257+
258+ torch .save (state , filename )
259+ return filename
260+
261+ def save_results (self , output , filename , node_level_predictions = False ):
262+ results_path = os .path .join (self .run_dir , "results" , self .timestamp_id )
263+ os .makedirs (results_path , exist_ok = True )
264+ filename = os .path .join (results_path , filename )
265+ shape = output .shape
266+
267+ id_headers = ["structure_id" ]
268+ if node_level_predictions :
269+ id_headers += ["node_id" ]
270+ num_cols = (shape [1 ] - len (id_headers )) // 2
271+ headers = id_headers + ["target" ] * num_cols + ["prediction" ] * num_cols
272+
273+ with open (filename , "w" ) as f :
274+ csvwriter = csv .writer (f )
275+ for i in range (0 , len (output )):
276+ if i == 0 :
277+ csvwriter .writerow (headers )
278+ elif i > 0 :
279+ csvwriter .writerow (output [i - 1 , :])
280+ return filename
281+
    def load_checkpoint(self):
        """Loads the model from a checkpoint.pt file"""
        # TODO: implement this method — intended counterpart to save_model(),
        # which writes state_dict / optimizer / scheduler / epoch / step /
        # best_val_metric; this should restore those fields. Currently a
        # no-op stub.
        pass
0 commit comments