Skip to content

Commit 9516757

Browse files
Aaron Higuera
authored and committed
add prior weights from RiceRNN
1 parent 6283d2b commit 9516757

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

DUNE/RiceBayes/BNN_model.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from functools import partial
23
import tensorflow as tf
34
from tensorflow.keras import layers, models
@@ -48,6 +49,21 @@ def prior(dtype, shape, name, trainable, add_variable_fn):
4849

4950
return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
5051

52+
def customize_prior(dtype, shape, name, trainable, add_variable_fn):
53+
"""
54+
Creates an customize normal distribution as a prior.
55+
56+
"""
57+
# Use weights from ResCNN as prior
58+
prior_1stconv2d = np.load('/home/higuera/CNN/conv2d_weights_RiceRes.npy')
59+
60+
mean = prior_1stconv2d
61+
62+
dist = tfd.MultivariateNormalDiag(loc = mean,
63+
scale_diag = 1.5*tf.ones(shape))
64+
batch_ndims = tf.size(dist.batch_shape_tensor())
65+
66+
return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
5167

5268
def get_convolution_reparameterization(filters, kernel_size, activation, strides = 1,
5369
padding = 'SAME',
@@ -158,10 +174,12 @@ def bayes_model(input_shape=(200,200,3)):
158174
Returns:
159175
tf.keras.Model: The constructed Keras model.
160176
"""
161-
177+
162178
inputs = layers.Input(shape=input_shape, name='inputs')
163179

164-
x = get_convolution_reparameterization(16, 3, 'swish')(inputs)
180+
x = get_convolution_reparameterization(32, 7, 'swish', strides=2,
181+
prior=customize_prior)(inputs)
182+
165183
x = layers.BatchNormalization(name='batchnorm_0')(x)
166184
x = layers.ReLU()(x)
167185
x = layers.MaxPooling2D(3, strides=2, padding='same')(x)

DUNE/RiceBayes/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def nll(y_true, y_pred):
117117
lr_scheduler = LearningRateSchedulerPlateau(factor=0.5, patience=5, min_lr=1e-6)
118118
history_filename = args.test_name+'_training_history.json'
119119
history_saver = SaveHistoryToFile(history_filename)
120-
early_stopper = EarlyStopping(monitor='val_loss', patience=3, min_delta=0.01, mode='min',
120+
early_stopper = EarlyStopping(monitor='val_loss', patience=3, min_delta=0.001, mode='min',
121121
restore_best_weights=True)
122122

123123
train_generator = DataGenerator(partition['train'], **params)

0 commit comments

Comments (0)