11from functools import partial
22import tensorflow as tf
33from tensorflow .keras import layers , models
4- from tensorflow .keras .callbacks import ModelCheckpoint
54from tensorflow .keras .optimizers .legacy import SGD
65import tensorflow_probability as tfp
76tfd = tfp .distributions
87tfpl = tfp .layers
98
10- num_samples = 1300667
9+ num_samples = 1510865
1110
1211
1312def kl_approx (q , p , q_tensor ):
@@ -30,7 +29,7 @@ def kl_approx(q, p, q_tensor):
3029
3130def prior (dtype , shape , name , trainable , add_variable_fn ):
3231 """
33- Creates an Independent multivariate normal distribution as a prior.
32+ Creates an customize multivariate normal distribution as a prior.
3433
3534 Args:
3635 dtype: The data type of the distribution's parameters.
@@ -42,8 +41,8 @@ def prior(dtype, shape, name, trainable, add_variable_fn):
4241 Returns:
4342 tfd.Independent: The Independent multivariate normal distribution.
4443 """
45- dist = tfd .MultivariateNormalDiag (loc = 1.2 * tf .ones (shape ),
46- scale_diag = 3.0 * tf .ones (shape ))
44+ dist = tfd .MultivariateNormalDiag (loc = 1.0 * tf .ones (shape ),
45+ scale_diag = 1.5 * tf .ones (shape ))
4746
4847 batch_ndims = tf .size (dist .batch_shape_tensor ())
4948
@@ -163,7 +162,7 @@ def bayes_model(input_shape=(200,200,3)):
163162 inputs = layers .Input (shape = input_shape , name = 'inputs' )
164163
165164 x = get_convolution_reparameterization (16 , 3 , 'swish' )(inputs )
166- x = layers .BatchNormalization ()(x )
165+ x = layers .BatchNormalization (name = 'batchnorm_0' )(x )
167166 x = layers .ReLU ()(x )
168167 x = layers .MaxPooling2D (3 , strides = 2 , padding = 'same' )(x )
169168
@@ -172,48 +171,23 @@ def bayes_model(input_shape=(200,200,3)):
172171
173172 for i in range (num_blocks ):
174173 x = residual_block (x , filters [i ], kernel_size = 3 ,
175- padding = 'same' , activation = tf . nn . silu ,
174+ padding = 'same' , activation = None ,
176175 pool_size = (2 , 2 ), strides = (1 , 1 ),
177176 name = 'residual_block' + str (i ))
178177
179178 x = tf .keras .layers .GlobalMaxPooling2D ()(x )
180- '''
181- x = tfpl.DenseReparameterization(
182- units=64, # This matches the number of units from the Dense layer
183- activation='relu', # Activation can be directly specified here
184- kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
185- kernel_prior_fn=tfpl.default_multivariate_normal_fn,
186- bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
187- bias_prior_fn=tfpl.default_multivariate_normal_fn,
188- kernel_divergence_fn=adjusted_divergence_fn,
189- bias_divergence_fn=adjusted_divergence_fn,
190- name='dense_reparam1'
191- )(x)
192-
193- x = tfpl.DenseReparameterization(
194- units=32, # This matches the number of units from the Dense layer
195- activation='sigmoid', # Activation can be directly specified here
196- kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
197- kernel_prior_fn=tfpl.default_multivariate_normal_fn,
198- bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
199- bias_prior_fn=tfpl.default_multivariate_normal_fn,
200- kernel_divergence_fn=adjusted_divergence_fn,
201- bias_divergence_fn=adjusted_divergence_fn,
202- name='dense_reparam2'
203- )(x)
204- '''
179+
205180 x = tfpl .DenseReparameterization (
206- units = tfpl .CategoricalMixtureOfOneHotCategorical .params_size (3 , 5 ), activation = None ,
181+ units = tfpl .CategoricalMixtureOfOneHotCategorical .params_size (3 , 1 ), activation = None ,
207182 kernel_posterior_fn = tfpl .default_mean_field_normal_fn (is_singular = False ),
208183 kernel_prior_fn = tfpl .default_multivariate_normal_fn ,
209184 bias_prior_fn = tfpl .default_multivariate_normal_fn ,
210185 bias_posterior_fn = tfpl .default_mean_field_normal_fn (is_singular = False ),
211186 kernel_divergence_fn = divergence_fn ,
212- bias_divergence_fn = divergence_fn ,
213- name = 'dense_reparam3' )(x )
214-
215- x = tfpl .CategoricalMixtureOfOneHotCategorical (event_size = 3 , num_components = 5 , name = 'output' )(x )
216-
187+ name = 'dense_reparam1' )(x )
188+
189+ x = tfpl .CategoricalMixtureOfOneHotCategorical (event_size = 3 , num_components = 1 , name = 'output' )(x )
190+
217191 model = models .Model (inputs , outputs = x , name = 'Rice_BNN' )
218192
219193 return model
0 commit comments