|
| 1 | +import numpy as np |
1 | 2 | from functools import partial |
2 | 3 | import tensorflow as tf |
3 | 4 | from tensorflow.keras import layers, models |
@@ -48,6 +49,21 @@ def prior(dtype, shape, name, trainable, add_variable_fn): |
48 | 49 |
|
49 | 50 | return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) |
50 | 51 |
|
| 52 | +def customize_prior(dtype, shape, name, trainable, add_variable_fn): |
| 53 | + """ |
| 54 | + Creates an customize normal distribution as a prior. |
| 55 | +
|
| 56 | + """ |
| 57 | + # Use weights from ResCNN as prior |
| 58 | + prior_1stconv2d = np.load('/home/higuera/CNN/conv2d_weights_RiceRes.npy') |
| 59 | + |
| 60 | + mean = prior_1stconv2d |
| 61 | + |
| 62 | + dist = tfd.MultivariateNormalDiag(loc = mean, |
| 63 | + scale_diag = 1.5*tf.ones(shape)) |
| 64 | + batch_ndims = tf.size(dist.batch_shape_tensor()) |
| 65 | + |
| 66 | + return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) |
51 | 67 |
|
52 | 68 | def get_convolution_reparameterization(filters, kernel_size, activation, strides = 1, |
53 | 69 | padding = 'SAME', |
@@ -158,10 +174,12 @@ def bayes_model(input_shape=(200,200,3)): |
158 | 174 | Returns: |
159 | 175 | tf.keras.Model: The constructed Keras model. |
160 | 176 | """ |
161 | | - |
| 177 | + |
162 | 178 | inputs = layers.Input(shape=input_shape, name='inputs') |
163 | 179 |
|
164 | | - x = get_convolution_reparameterization(16, 3, 'swish')(inputs) |
| 180 | + x = get_convolution_reparameterization(32, 7, 'swish', strides=2, |
| 181 | + prior=customize_prior)(inputs) |
| 182 | + |
165 | 183 | x = layers.BatchNormalization(name='batchnorm_0')(x) |
166 | 184 | x = layers.ReLU()(x) |
167 | 185 | x = layers.MaxPooling2D(3, strides=2, padding='same')(x) |
|
0 commit comments