77tfd = tfp .distributions
88tfpl = tfp .layers
99
10- num_samples = 683620
10+ num_samples = 1300667
1111
1212
1313def kl_approx (q , p , q_tensor ):
@@ -25,20 +25,7 @@ def kl_approx(q, p, q_tensor):
2525
2626 return tf .reduce_mean (q .log_prob (q_tensor ) - p .log_prob (q_tensor ))
2727
28- def divergence_fn (q , p , q_tensor , num_samples = num_samples ):
29- """
30- Normalizes the KL divergence approximation by the number of samples.
31-
32- Args:
33- q (tf.distributions.Distribution): The first distribution.
34- p (tf.distributions.Distribution): The second distribution.
35- q_tensor (tf.Tensor): The tensor to evaluate the log probabilities.
36- num_samples (int): The number of samples for normalization.
37-
38- Returns:
39- tf.Tensor: The normalized KL divergence.
40- """
41- return kl_approx (q , p , q_tensor ) / num_samples
28+ divergence_fn = lambda q , p , q_tensor : kl_approx (q , p , q_tensor ) / num_samples
4229
4330
4431def prior (dtype , shape , name , trainable , add_variable_fn ):
@@ -62,13 +49,11 @@ def prior(dtype, shape, name, trainable, add_variable_fn):
6249
6350 return tfd .Independent (dist , reinterpreted_batch_ndims = batch_ndims )
6451
65- adjusted_divergence_fn = partial (divergence_fn , num_samples = num_samples )
66-
6752
6853def get_convolution_reparameterization (filters , kernel_size , activation , strides = 1 ,
6954 padding = 'SAME' ,
7055 prior = prior ,
71- divergence_fn = adjusted_divergence_fn ,
56+ divergence_fn = divergence_fn ,
7257 name = None ) -> tfpl .Convolution2DReparameterization :
7358 """
7459 Creates a Convolution2DReparameterization layer.
@@ -95,11 +80,11 @@ def get_convolution_reparameterization(filters, kernel_size, activation, strides
9580
9681 kernel_posterior_fn = tfpl .default_mean_field_normal_fn (is_singular = False ),
9782 kernel_prior_fn = prior ,
98- kernel_divergence_fn = adjusted_divergence_fn ,
83+ kernel_divergence_fn = divergence_fn ,
9984
10085 bias_posterior_fn = tfpl .default_mean_field_normal_fn (is_singular = False ),
10186 bias_prior_fn = prior ,
102- bias_divergence_fn = adjusted_divergence_fn ,
87+ bias_divergence_fn = divergence_fn ,
10388 name = name )
10489
10590
@@ -182,7 +167,7 @@ def bayes_model(input_shape=(200,200,3)):
182167 x = layers .ReLU ()(x )
183168 x = layers .MaxPooling2D (3 , strides = 2 , padding = 'same' )(x )
184169
185- num_blocks = 4 # Increase the number of residual blocks
170+ num_blocks = 3 # Increase the number of residual blocks
186171 filters = [32 , 64 , 128 , 256 , 512 ] # Increase the number of filters in each block
187172
188173 for i in range (num_blocks ):
@@ -223,8 +208,8 @@ def bayes_model(input_shape=(200,200,3)):
223208 kernel_prior_fn = tfpl .default_multivariate_normal_fn ,
224209 bias_prior_fn = tfpl .default_multivariate_normal_fn ,
225210 bias_posterior_fn = tfpl .default_mean_field_normal_fn (is_singular = False ),
226- kernel_divergence_fn = adjusted_divergence_fn ,
227- bias_divergence_fn = adjusted_divergence_fn ,
211+ kernel_divergence_fn = divergence_fn ,
212+ bias_divergence_fn = divergence_fn ,
228213 name = 'dense_reparam3' )(x )
229214
230215 x = tfpl .CategoricalMixtureOfOneHotCategorical (event_size = 3 , num_components = 5 , name = 'output' )(x )