1+ import csv
12import logging
3+ import os
24from abc import ABC , abstractmethod
5+ from datetime import datetime
36
47import torch
58import torch .optim as optim
1720from matdeeplearn .common .registry import registry
1821from matdeeplearn .models .base_model import BaseModel
1922from matdeeplearn .modules .evaluator import Evaluator
20- from matdeeplearn .modules .loss import *
2123from matdeeplearn .modules .scheduler import LRScheduler
2224
2325
@@ -35,6 +37,7 @@ def __init__(
3537 test_loader : DataLoader ,
3638 loss : nn .Module ,
3739 max_epochs : int ,
40+ identifier : str = None ,
3841 verbosity : int = None ,
3942 ):
4043 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -56,9 +59,19 @@ def __init__(
5659 self .step = 0
5760 self .metrics = {}
5861 self .epoch_time = None
62+ self .best_val_metric = 1e10
5963
6064 self .evaluator = Evaluator ()
6165
66+ self .run_dir = os .getcwd ()
67+
68+ timestamp = torch .tensor (datetime .now ().timestamp ()).to (self .device )
69+ self .timestamp_id = datetime .fromtimestamp (timestamp .int ()).strftime (
70+ "%Y-%m-%d-%H-%M-%S"
71+ )
72+ if identifier :
73+ self .timestamp_id = f"{ self .timestamp_id } -{ identifier } "
74+
6275 if self .train_verbosity :
6376 logging .info (
6477 f"GPU is available: { torch .cuda .is_available ()} , Quantity: { torch .cuda .device_count ()} "
@@ -94,6 +107,7 @@ def from_config(cls, config):
94107 loss = cls ._load_loss (config ["optim" ]["loss" ])
95108
96109 max_epochs = config ["optim" ]["max_epochs" ]
110+ identifier = config ["task" ].get ("identifier" , None )
97111 verbosity = config ["task" ].get ("verbosity" , None )
98112
99113 return cls (
@@ -107,6 +121,7 @@ def from_config(cls, config):
107121 test_loader = test_loader ,
108122 loss = loss ,
109123 max_epochs = max_epochs ,
124+ identifier = identifier ,
110125 verbosity = verbosity ,
111126 )
112127
@@ -180,15 +195,12 @@ def _load_scheduler(scheduler_config, optimizer):
180195 @staticmethod
181196 def _load_loss (loss_config ):
182197 """Loads the loss from either the TorchLossWrapper or custom loss functions in matdeeplearn"""
183- try :
184- loss_type = loss_config ["loss_type" ]
185- # if there are other params for loss type, include in call
186- if loss_config .get ("loss_args" ):
187- return eval (loss_type )(** loss_config ["loss_args" ])
188- else :
189- return eval (loss_type )()
190- except (AttributeError , NameError ):
191- raise NotImplementedError (f"Unknown loss class name: { loss_type } " )
198+ loss_cls = registry .get_loss_class (loss_config ["loss_type" ])
199+ # if there are other params for loss type, include in call
200+ if loss_config .get ("loss_args" ):
201+ return loss_cls (** loss_config ["loss_args" ])
202+ else :
203+ return loss_cls ()
192204
193205 @abstractmethod
194206 def _load_task (self ):
@@ -205,3 +217,23 @@ def validate(self):
    @abstractmethod
    def predict(self):
        """Implemented by derived classes."""
220+
221+ def save_results (self , output , filename , node_level_predictions = False ):
222+ results_path = os .path .join (self .run_dir , "results" , self .timestamp_id )
223+ os .makedirs (results_path , exist_ok = True )
224+ filename = os .path .join (results_path , filename )
225+ shape = output .shape
226+
227+ id_headers = ["structure_id" ]
228+ if node_level_predictions :
229+ id_headers += ["node_id" ]
230+ num_cols = (shape [1 ] - len (id_headers )) // 2
231+ headers = id_headers + ["target" ] * num_cols + ["prediction" ] * num_cols
232+
233+ with open (filename , "w" ) as f :
234+ csvwriter = csv .writer (f )
235+ for i in range (0 , len (output )):
236+ if i == 0 :
237+ csvwriter .writerow (headers )
238+ elif i > 0 :
239+ csvwriter .writerow (output [i - 1 , :])
0 commit comments