Skip to content

Commit 6283d2b

Browse files
Aaron HigueraAaron Higuera
authored andcommitted
keep up with updates
1 parent ff646b4 commit 6283d2b

2 files changed

Lines changed: 20 additions & 44 deletions

File tree

DUNE/RiceBayes/BNN_model.py

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from functools import partial
22
import tensorflow as tf
33
from tensorflow.keras import layers, models
4-
from tensorflow.keras.callbacks import ModelCheckpoint
54
from tensorflow.keras.optimizers.legacy import SGD
65
import tensorflow_probability as tfp
76
tfd = tfp.distributions
87
tfpl = tfp.layers
98

10-
num_samples = 1300667
9+
num_samples = 1510865
1110

1211

1312
def kl_approx(q, p, q_tensor):
@@ -30,7 +29,7 @@ def kl_approx(q, p, q_tensor):
3029

3130
def prior(dtype, shape, name, trainable, add_variable_fn):
3231
"""
33-
Creates an Independent multivariate normal distribution as a prior.
32+
Creates an customize multivariate normal distribution as a prior.
3433
3534
Args:
3635
dtype: The data type of the distribution's parameters.
@@ -42,8 +41,8 @@ def prior(dtype, shape, name, trainable, add_variable_fn):
4241
Returns:
4342
tfd.Independent: The Independent multivariate normal distribution.
4443
"""
45-
dist = tfd.MultivariateNormalDiag(loc=1.2 * tf.ones(shape),
46-
scale_diag=3.0*tf.ones(shape))
44+
dist = tfd.MultivariateNormalDiag(loc = 1.0*tf.ones(shape),
45+
scale_diag = 1.5*tf.ones(shape))
4746

4847
batch_ndims = tf.size(dist.batch_shape_tensor())
4948

@@ -163,7 +162,7 @@ def bayes_model(input_shape=(200,200,3)):
163162
inputs = layers.Input(shape=input_shape, name='inputs')
164163

165164
x = get_convolution_reparameterization(16, 3, 'swish')(inputs)
166-
x = layers.BatchNormalization()(x)
165+
x = layers.BatchNormalization(name='batchnorm_0')(x)
167166
x = layers.ReLU()(x)
168167
x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
169168

@@ -172,48 +171,23 @@ def bayes_model(input_shape=(200,200,3)):
172171

173172
for i in range(num_blocks):
174173
x = residual_block(x, filters[i], kernel_size = 3,
175-
padding = 'same', activation = tf.nn.silu,
174+
padding = 'same', activation = None,
176175
pool_size = (2, 2), strides = (1, 1),
177176
name = 'residual_block'+str(i))
178177

179178
x = tf.keras.layers.GlobalMaxPooling2D()(x)
180-
'''
181-
x = tfpl.DenseReparameterization(
182-
units=64, # This matches the number of units from the Dense layer
183-
activation='relu', # Activation can be directly specified here
184-
kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
185-
kernel_prior_fn=tfpl.default_multivariate_normal_fn,
186-
bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
187-
bias_prior_fn=tfpl.default_multivariate_normal_fn,
188-
kernel_divergence_fn=adjusted_divergence_fn,
189-
bias_divergence_fn=adjusted_divergence_fn,
190-
name='dense_reparam1'
191-
)(x)
192-
193-
x = tfpl.DenseReparameterization(
194-
units=32, # This matches the number of units from the Dense layer
195-
activation='sigmoid', # Activation can be directly specified here
196-
kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
197-
kernel_prior_fn=tfpl.default_multivariate_normal_fn,
198-
bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
199-
bias_prior_fn=tfpl.default_multivariate_normal_fn,
200-
kernel_divergence_fn=adjusted_divergence_fn,
201-
bias_divergence_fn=adjusted_divergence_fn,
202-
name='dense_reparam2'
203-
)(x)
204-
'''
179+
205180
x = tfpl.DenseReparameterization(
206-
units = tfpl.CategoricalMixtureOfOneHotCategorical.params_size(3, 5), activation = None,
181+
units = tfpl.CategoricalMixtureOfOneHotCategorical.params_size(3, 1), activation = None,
207182
kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
208183
kernel_prior_fn = tfpl.default_multivariate_normal_fn,
209184
bias_prior_fn = tfpl.default_multivariate_normal_fn,
210185
bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
211186
kernel_divergence_fn=divergence_fn,
212-
bias_divergence_fn=divergence_fn,
213-
name = 'dense_reparam3')(x)
214-
215-
x = tfpl.CategoricalMixtureOfOneHotCategorical(event_size = 3, num_components = 5, name = 'output')(x)
216-
187+
name = 'dense_reparam1')(x)
188+
189+
x = tfpl.CategoricalMixtureOfOneHotCategorical(event_size = 3, num_components = 1, name = 'output')(x)
190+
217191
model = models.Model(inputs, outputs=x, name='Rice_BNN')
218192

219193
return model

DUNE/RiceBayes/train.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from generator_class import DataGenerator
1111

1212
import tensorflow as tf
13-
from tensorflow.keras import datasets, layers, models, optimizers, callbacks
14-
from tensorflow.keras.callbacks import ModelCheckpoint
13+
from tensorflow.keras import layers, models, optimizers, callbacks
14+
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
1515
from tensorflow.keras.optimizers.legacy import SGD
1616

1717
#GPU/CPU Selection
@@ -22,7 +22,7 @@ class LearningRateSchedulerPlateau(callbacks.Callback):
2222
'''
2323
Learning rate scheduler
2424
'''
25-
def __init__(self, factor=0.5, patience=5, min_lr=1e-6):
25+
def __init__(self, factor=0.5, patience=5, min_lr=1e-4):
2626
super(LearningRateSchedulerPlateau, self).__init__()
2727
self.factor = factor # Factor by which the learning rate will be reduced
2828
self.patience = patience # Number of epochs with no improvement after which learning rate will be reduced
@@ -81,8 +81,8 @@ def nll(y_true, y_pred):
8181
if __name__ == "__main__":
8282
parser = argparse.ArgumentParser()
8383
parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs')
84-
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
85-
parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
84+
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
85+
parser.add_argument('--learning_rate', type=float, default=1e-2, help='Learning rate')
8686
parser.add_argument('--pixel_map_size', type=int, default=200, help='Pixel map size square shape')
8787
parser.add_argument('--pixel_maps', type=str, help='Pre-selected pixel maps ')
8888
parser.add_argument('--test_name', type=str, default='test', help='name of model and plots')
@@ -117,12 +117,14 @@ def nll(y_true, y_pred):
117117
lr_scheduler = LearningRateSchedulerPlateau(factor=0.5, patience=5, min_lr=1e-6)
118118
history_filename = args.test_name+'_training_history.json'
119119
history_saver = SaveHistoryToFile(history_filename)
120+
early_stopper = EarlyStopping(monitor='val_loss', patience=3, min_delta=0.01, mode='min',
121+
restore_best_weights=True)
120122

121123
train_generator = DataGenerator(partition['train'], **params)
122124
validation_generator = DataGenerator(partition['validation'], **params)
123125

124126
model.fit(train_generator,validation_data=validation_generator,
125-
epochs=args.num_epochs, callbacks=[lr_scheduler, history_saver])
127+
epochs=args.num_epochs, callbacks=[lr_scheduler, history_saver, early_stopper])
126128

127129
# for inferences need to save weights
128130
weights = args.test_name+'.h5'

0 commit comments

Comments
 (0)