Skip to content

Commit e6db101

Browse files
Aaron HigueraAaron Higuera
authored andcommitted
update model
1 parent 9516757 commit e6db101

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

DUNE/RiceBayes/BNN_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
tfd = tfp.distributions
88
tfpl = tfp.layers
99

10-
num_samples = 1510865
10+
num_samples = 2023317
1111

1212

1313
def kl_approx(q, p, q_tensor):
@@ -184,7 +184,7 @@ def bayes_model(input_shape=(200,200,3)):
184184
x = layers.ReLU()(x)
185185
x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
186186

187-
num_blocks = 3 # Increase the number of residual blocks
187+
num_blocks = 4 # Increase the number of residual blocks
188188
filters = [32, 64, 128, 256, 512] # Increase the number of filters in each block
189189

190190
for i in range(num_blocks):

DUNE/RiceBayes/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def on_epoch_end(self, epoch, logs=None):
6363

6464
with open(self.file_path, 'w') as file:
6565
json.dump(self.history, file)
66-
66+
67+
6768

6869
def nll(y_true, y_pred):
6970
"""
@@ -110,8 +111,9 @@ def nll(y_true, y_pred):
110111

111112
model.compile(optimizer=sgd_optimizer, loss=nll, metrics=['accuracy'])
112113

113-
checkpoint_filepath = '/Users/aaronhiguera/HEP/DUNE/Machine-Learning/DUNE/RiceBayes/'
114-
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath, save_best_only=False)
114+
checkpoint_filepath = '/home/higuera/RiceBNN/'
115+
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath,
116+
save_best_only=True, save_weights_only=True, monitor='val_loss')
115117

116118
# Define the learning rate scheduler callback and history saver
117119
lr_scheduler = LearningRateSchedulerPlateau(factor=0.5, patience=5, min_lr=1e-6)
@@ -124,7 +126,7 @@ def nll(y_true, y_pred):
124126
validation_generator = DataGenerator(partition['validation'], **params)
125127

126128
model.fit(train_generator,validation_data=validation_generator,
127-
epochs=args.num_epochs, callbacks=[lr_scheduler, history_saver, early_stopper])
129+
epochs=args.num_epochs, callbacks=[lr_scheduler, history_saver, early_stopper, checkpoint_callback])
128130

129131
# for inferences need to save weights
130132
weights = args.test_name+'.h5'

0 commit comments

Comments
 (0)