@@ -63,7 +63,8 @@ def on_epoch_end(self, epoch, logs=None):
6363
6464 with open (self .file_path , 'w' ) as file :
6565 json .dump (self .history , file )
66-
66+
67+
6768
6869def nll (y_true , y_pred ):
6970 """
@@ -110,8 +111,9 @@ def nll(y_true, y_pred):
110111
111112 model .compile (optimizer = sgd_optimizer , loss = nll , metrics = ['accuracy' ])
112113
113- checkpoint_filepath = '/Users/aaronhiguera/HEP/DUNE/Machine-Learning/DUNE/RiceBayes/'
114- checkpoint_callback = ModelCheckpoint (filepath = checkpoint_filepath , save_best_only = False )
114+ checkpoint_filepath = '/home/higuera/RiceBNN/'
115+ checkpoint_callback = ModelCheckpoint (filepath = checkpoint_filepath ,
116+ save_best_only = True , save_weights_only = True , monitor = 'val_loss' )
115117
116118 # Define the learning rate scheduler callback and history saver
117119 lr_scheduler = LearningRateSchedulerPlateau (factor = 0.5 , patience = 5 , min_lr = 1e-6 )
@@ -124,7 +126,7 @@ def nll(y_true, y_pred):
124126 validation_generator = DataGenerator (partition ['validation' ], ** params )
125127
126128 model .fit (train_generator ,validation_data = validation_generator ,
127- epochs = args .num_epochs , callbacks = [lr_scheduler , history_saver , early_stopper ])
129+ epochs = args .num_epochs , callbacks = [lr_scheduler , history_saver , early_stopper , checkpoint_callback ])
128130
129131 # for inferences need to save weights
130132 weights = args .test_name + '.h5'
0 commit comments