Skip to content

Commit ff646b4

Browse files
Aaron Higuera
authored and committed
update model and training script
1 parent 9ccdc12 commit ff646b4

2 files changed

Lines changed: 10 additions & 25 deletions

File tree

DUNE/RiceBayes/BNN_model.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
tfd = tfp.distributions
88
tfpl = tfp.layers
99

10-
num_samples = 683620
10+
num_samples = 1300667
1111

1212

1313
def kl_approx(q, p, q_tensor):
@@ -25,20 +25,7 @@ def kl_approx(q, p, q_tensor):
2525

2626
return tf.reduce_mean(q.log_prob(q_tensor) - p.log_prob(q_tensor))
2727

28-
def divergence_fn(q, p, q_tensor, num_samples=num_samples):
29-
"""
30-
Normalizes the KL divergence approximation by the number of samples.
31-
32-
Args:
33-
q (tf.distributions.Distribution): The first distribution.
34-
p (tf.distributions.Distribution): The second distribution.
35-
q_tensor (tf.Tensor): The tensor to evaluate the log probabilities.
36-
num_samples (int): The number of samples for normalization.
37-
38-
Returns:
39-
tf.Tensor: The normalized KL divergence.
40-
"""
41-
return kl_approx(q, p, q_tensor) / num_samples
28+
divergence_fn = lambda q, p, q_tensor : kl_approx(q, p, q_tensor) / num_samples
4229

4330

4431
def prior(dtype, shape, name, trainable, add_variable_fn):
@@ -62,13 +49,11 @@ def prior(dtype, shape, name, trainable, add_variable_fn):
6249

6350
return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
6451

65-
adjusted_divergence_fn = partial(divergence_fn, num_samples=num_samples)
66-
6752

6853
def get_convolution_reparameterization(filters, kernel_size, activation, strides = 1,
6954
padding = 'SAME',
7055
prior = prior,
71-
divergence_fn = adjusted_divergence_fn,
56+
divergence_fn = divergence_fn,
7257
name = None) -> tfpl.Convolution2DReparameterization:
7358
"""
7459
Creates a Convolution2DReparameterization layer.
@@ -95,11 +80,11 @@ def get_convolution_reparameterization(filters, kernel_size, activation, strides
9580

9681
kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
9782
kernel_prior_fn=prior,
98-
kernel_divergence_fn=adjusted_divergence_fn,
83+
kernel_divergence_fn=divergence_fn,
9984

10085
bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
10186
bias_prior_fn=prior,
102-
bias_divergence_fn=adjusted_divergence_fn,
87+
bias_divergence_fn=divergence_fn,
10388
name=name)
10489

10590

@@ -182,7 +167,7 @@ def bayes_model(input_shape=(200,200,3)):
182167
x = layers.ReLU()(x)
183168
x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
184169

185-
num_blocks = 4 # Increase the number of residual blocks
170+
num_blocks = 3 # Increase the number of residual blocks
186171
filters = [32, 64, 128, 256, 512] # Increase the number of filters in each block
187172

188173
for i in range(num_blocks):
@@ -223,8 +208,8 @@ def bayes_model(input_shape=(200,200,3)):
223208
kernel_prior_fn = tfpl.default_multivariate_normal_fn,
224209
bias_prior_fn = tfpl.default_multivariate_normal_fn,
225210
bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
226-
kernel_divergence_fn=adjusted_divergence_fn,
227-
bias_divergence_fn=adjusted_divergence_fn,
211+
kernel_divergence_fn=divergence_fn,
212+
bias_divergence_fn=divergence_fn,
228213
name = 'dense_reparam3')(x)
229214

230215
x = tfpl.CategoricalMixtureOfOneHotCategorical(event_size = 3, num_components = 5, name = 'output')(x)

DUNE/RiceBayes/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,6 @@ def nll(y_true, y_pred):
128128
weights = args.test_name+'.h5'
129129
model.save_weights(weights)
130130
## !!!!!need to fix this!!!
131-
#complete_model = args.test_name
132-
#model.save(complete_model)
131+
complete_model = args.test_name
132+
model.save(complete_model)
133133

0 commit comments

Comments (0)