99
1010class SuperResolutionModelCNN (ModelBase ):
1111 def __init__ (self ):
12- self .parser = argparse . ArgumentParser ( add_help = False ) # , prog="Basic SuperResolution",
13- # description="Basic Deep Convolutional Super Resolution" )
14- #
12+ self .possible_upscale = [ 2 , 4 ]
13+ self . parser = argparse . ArgumentParser ( add_help = False )
14+
1515 self .parser .add_argument ('--regularization' , dest = 'regularization' ,
1616 type = float ,
1717 default = 0.000001 ,
@@ -30,13 +30,20 @@ def load_argument(self) -> argparse.ArgumentParser:
3030 return self .parser
3131
3232 def create_model (self , input_shape , output_shape , ** kwargs ) -> keras .Model :
33+ scale_factor : int = int (output_shape [0 ] / input_shape [0 ])
34+ scale_factor : int = int (output_shape [1 ] / input_shape [1 ])
35+
36+ if scale_factor not in self .possible_upscale and scale_factor not in self .possible_upscale :
37+ raise ValueError ("Invalid upscale" )
38+
3339 # Model Construct Parameters.
3440 regularization : float = kwargs .get ("regularization" , 0.000001 ) #
35- upscale_mode : int = kwargs .get ("upscale_mode" , 2 ) #
41+ upscale_mode : int = scale_factor #
42+ num_input_filters : int = kwargs .get ("input_filters" , 64 ) #
3643
3744 #
3845 return create_cnn_model (input_shape = input_shape ,
39- output_shape = output_shape , input_filter_size = 64 , regularization = regularization ,
46+ output_shape = output_shape , input_filter_size = num_input_filters , regularization = regularization ,
4047 upscale_mode = upscale_mode ,
4148 kernel_activation = 'relu' )
4249
@@ -51,6 +58,7 @@ def get_model_interface() -> ModelBase:
5158def create_cnn_model (input_shape : tuple , output_shape : tuple , input_filter_size : int , regularization : float ,
5259 upscale_mode : int ,
5360 kernel_activation : str ):
61+
5462 use_batch_norm : bool = True
5563 use_bias : bool = True
5664 num_conv_block : int = 3
@@ -63,19 +71,18 @@ def create_cnn_model(input_shape: tuple, output_shape: tuple, input_filter_size:
6371 x = layers .UpSampling2D (size = (2 , 2 ), interpolation = 'bilinear' )(x )
6472
6573 # Convolutional block
66- for i in range (0 , int (upscale_mode / 2 )):
67- for _ in range (0 , num_conv_block ):
68- filter_size = input_filter_size << i
69- x = layers .Conv2D (filters = filter_size , kernel_size = (3 , 3 ), strides = 1 , padding = 'same' , use_bias = use_bias ,
70- kernel_initializer = tf .keras .initializers .HeNormal ())(x )
71- if use_batch_norm :
72- x = layers .BatchNormalization (dtype = 'float32' )(x )
73- x = create_activation (kernel_activation )(x )
74+ for _ in range (0 , num_conv_block ):
75+ filter_size = input_filter_size << i
76+ x = layers .Conv2D (filters = filter_size , kernel_size = (3 , 3 ), strides = 1 , padding = 'same' , use_bias = use_bias ,
77+ kernel_initializer = tf .keras .initializers .HeNormal ())(x )
78+ if use_batch_norm :
79+ x = layers .BatchNormalization (dtype = 'float32' )(x )
80+ x = create_activation (kernel_activation )(x )
7481
7582 # Output to 3 channel output.
7683 x = layers .Conv2DTranspose (filters = output_channels , kernel_size = (9 , 9 ), strides = (
7784 1 , 1 ), padding = 'same' , use_bias = use_bias , kernel_initializer = tf .keras .initializers .HeNormal (),
78- bias_initializer = tf .keras .initializers .HeNormal ())(x )
85+ bias_initializer = tf .keras .initializers .HeNormal ())(x )
7986 x = layers .Activation ('tanh' )(x )
8087 x = layers .ActivityRegularization (l1 = regularization , l2 = 0 )(x )
8188
0 commit comments