Skip to content

Commit b793a7e

Browse files
committed
Added 4x upscale on basic super resolution model and other minor changes
1 parent aab6e75 commit b793a7e

10 files changed

Lines changed: 46 additions & 39 deletions

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,10 @@ optional arguments:
191191
## Installation Instructions
192192
193193
### Setup Virtual Environment
194-
194+
python3.9 or higher
195195
```bash
196196
python3 -m venv venv
197+
source venv/bin/activate
197198
```
198199
199200
## Installing Required Packages

superresolution/SuperResolution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
299299

300300
# TODO: improve
301301
if os.path.exists(checkpoint_root_path):
302-
custom_objects = {'PSNRMetric' : PSNRMetric(), 'VGG16Error' : VGG16Error()}
302+
custom_objects = {'PSNRMetric': PSNRMetric(), 'VGG16Error': VGG16Error()}
303303
training_model = tf.keras.models.load_model(checkpoint_root_path, custom_objects=custom_objects)
304304

305305
# Create a callback that saves the model weights
@@ -320,7 +320,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
320320
training_callbacks.append(example_result_call_back)
321321

322322
# Debug output of the trained augmented data.
323-
#training_callbacks.append(SaveExampleResultImageCallBack(
323+
# training_callbacks.append(SaveExampleResultImageCallBack(
324324
# args.output_dir,
325325
# training_dataset, args.color_space, fileprefix="trainSuperResolution",
326326
# nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size))
@@ -429,7 +429,7 @@ def dcsuperresolution_program(vargs=None):
429429
#
430430
parser.add_argument('--model', dest='model',
431431
default='dcsr',
432-
choices=['cnnsr', 'dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
432+
choices=['dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
433433
'vdsr'],
434434
help='Set which model type to use.', type=str)
435435
#
@@ -482,7 +482,7 @@ def dcsuperresolution_program(vargs=None):
482482
args.model_filepath = os.path.join(args.output_dir, args.model_filepath)
483483

484484
# Allow override to enable cropping for increase details in the dataset.
485-
override_size: tuple = (512, 512) # TODO fix.
485+
override_size: tuple = (768, 768) # TODO fix.
486486

487487
# Setup Dataset
488488
training_dataset = None

superresolution/models/DCSuperResolution.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
class DCSuperResolutionModel(ModelBase):
1111
def __init__(self):
12+
self.possible_upscale = [2, 4]
1213
self.parser = argparse.ArgumentParser(add_help=False, prog="Basic SuperResolution",
1314
description="Basic Deep Convolutional Super Resolution")
1415
group = self.parser.add_argument_group(self.get_name())
@@ -31,9 +32,15 @@ def load_argument(self) -> argparse.ArgumentParser:
3132
return self.parser
3233

3334
def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
35+
scale_factor: int = int(output_shape[0] / input_shape[0])
36+
scale_factor: int = int(output_shape[1] / input_shape[1])
37+
38+
if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
39+
raise ValueError("Invalid upscale")
40+
3441
# Model Construct Parameters.
3542
regularization: float = kwargs.get("regularization", 0.000001) #
36-
upscale_mode: int = kwargs.get("upscale_mode", 2) #
43+
upscale_mode: int = scale_factor
3744
nr_filters: int = kwargs.get("filters", 64)
3845

3946
#

superresolution/models/SuperResolutionCNN.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
class SuperResolutionModelCNN(ModelBase):
1111
def __init__(self):
12-
self.parser = argparse.ArgumentParser(add_help=False) # , prog="Basic SuperResolution",
13-
# description="Basic Deep Convolutional Super Resolution")
14-
#
12+
self.possible_upscale = [2, 4]
13+
self.parser = argparse.ArgumentParser(add_help=False)
14+
1515
self.parser.add_argument('--regularization', dest='regularization',
1616
type=float,
1717
default=0.000001,
@@ -30,13 +30,20 @@ def load_argument(self) -> argparse.ArgumentParser:
3030
return self.parser
3131

3232
def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
33+
scale_factor: int = int(output_shape[0] / input_shape[0])
34+
scale_factor: int = int(output_shape[1] / input_shape[1])
35+
36+
if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
37+
raise ValueError("Invalid upscale")
38+
3339
# Model Construct Parameters.
3440
regularization: float = kwargs.get("regularization", 0.000001) #
35-
upscale_mode: int = kwargs.get("upscale_mode", 2) #
41+
upscale_mode: int = scale_factor #
42+
num_input_filters: int = kwargs.get("input_filters", 64) #
3643

3744
#
3845
return create_cnn_model(input_shape=input_shape,
39-
output_shape=output_shape, input_filter_size=64, regularization=regularization,
46+
output_shape=output_shape, input_filter_size=num_input_filters, regularization=regularization,
4047
upscale_mode=upscale_mode,
4148
kernel_activation='relu')
4249

@@ -51,6 +58,7 @@ def get_model_interface() -> ModelBase:
5158
def create_cnn_model(input_shape: tuple, output_shape: tuple, input_filter_size: int, regularization: float,
5259
upscale_mode: int,
5360
kernel_activation: str):
61+
5462
use_batch_norm: bool = True
5563
use_bias: bool = True
5664
num_conv_block: int = 3
@@ -63,19 +71,18 @@ def create_cnn_model(input_shape: tuple, output_shape: tuple, input_filter_size:
6371
x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
6472

6573
# Convolutional block
66-
for i in range(0, int(upscale_mode / 2)):
67-
for _ in range(0, num_conv_block):
68-
filter_size = input_filter_size << i
69-
x = layers.Conv2D(filters=filter_size, kernel_size=(3, 3), strides=1, padding='same', use_bias=use_bias,
70-
kernel_initializer=tf.keras.initializers.HeNormal())(x)
71-
if use_batch_norm:
72-
x = layers.BatchNormalization(dtype='float32')(x)
73-
x = create_activation(kernel_activation)(x)
74+
for _ in range(0, num_conv_block):
75+
filter_size = input_filter_size << i
76+
x = layers.Conv2D(filters=filter_size, kernel_size=(3, 3), strides=1, padding='same', use_bias=use_bias,
77+
kernel_initializer=tf.keras.initializers.HeNormal())(x)
78+
if use_batch_norm:
79+
x = layers.BatchNormalization(dtype='float32')(x)
80+
x = create_activation(kernel_activation)(x)
7481

7582
# Output to 3 channel output.
7683
x = layers.Conv2DTranspose(filters=output_channels, kernel_size=(9, 9), strides=(
7784
1, 1), padding='same', use_bias=use_bias, kernel_initializer=tf.keras.initializers.HeNormal(),
78-
bias_initializer=tf.keras.initializers.HeNormal())(x)
85+
bias_initializer=tf.keras.initializers.HeNormal())(x)
7986
x = layers.Activation('tanh')(x)
8087
x = layers.ActivityRegularization(l1=regularization, l2=0)(x)
8188

superresolution/models/SuperResolutionResNet.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ def __init__(self):
1919
default=0.001,
2020
help='Set the L1 Regularization applied.')
2121

22-
self.parser.add_argument('--upscale-mode', dest='upscale_mode',
23-
type=str,
24-
choices=[''],
25-
default='',
26-
help='Set the L1 Regularization applied.')
27-
2822
def load_argument(self) -> argparse.ArgumentParser:
2923
"""Load in the file for extracting text."""
3024

superresolution/models/SuperResolutionVDSR.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ class VDSRSuperResolutionModel(ModelBase):
1212
def __init__(self):
1313
self.possible_upscale = [2, 4]
1414

15-
1615
self.parser = argparse.ArgumentParser(add_help=False)
1716
#
1817
self.parser.add_argument('--regularization', dest='regularization',
@@ -35,9 +34,6 @@ def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
3534

3635
if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
3736
raise ValueError("Invalid upscale")
38-
39-
# parser_result = self.parser.parse_known_args(sys.argv[1:])
40-
# Model constructor parameters.
4137

4238
regularization: float = kwargs.get("regularization", 0.00001) #
4339
upscale_mode: int = scale_factor #

superresolution/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from tensorflow.keras import layers
22

33

4-
def create_activation(activation):
4+
def create_activation(activation: str):
55
if activation == "leaky_relu":
66
return layers.LeakyReLU(alpha=0.2, dtype='float32')
77
elif activation == "relu":

superresolution/util/convert_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tensorflow as tf
22

3-
def convert_model(model, dataset=None):
43

4+
def convert_model(model, dataset=None):
55
converter = tf.lite.TFLiteConverter.from_keras_model(model)
66
converter.optimizations = [tf.lite.Optimize.DEFAULT]
77
converter.inference_input_type = tf.float32
@@ -14,7 +14,7 @@ def convert_model(model, dataset=None):
1414
]
1515

1616
converter.post_training_quantize = True
17-
17+
1818
if dataset:
1919
converter.representative_dataset = tf.lite.RepresentativeDataset(
2020
dataset)

superresolution/util/dataProcessing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def DownScaleLayer(data):
150150
interpolation='bilinear',
151151
crop_to_aspect_ratio=False
152152
)])
153-
153+
154154
expectedScale = tf.keras.Sequential([
155155
layers.Resizing(
156156
output_size[0],
@@ -181,7 +181,7 @@ def resize_data(images):
181181

182182
if crop:
183183
dataset = dataset.map(resize_data)
184-
184+
185185
DownScaledDataSet = (
186186
dataset
187187
.map(DownScaleLayer,

superresolution/util/trainingcallback.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def compute_normalized_PSNR(orignal, data):
1313

1414
class SaveExampleResultImageCallBack(tf.keras.callbacks.Callback):
1515

16-
def __init__(self, dir_path, train_data_subset, color_space: str, nth_batch_sample: int = 512, grid_size: int = 6, fileprefix: str = "SuperResolution",
16+
def __init__(self, dir_path, train_data_subset, color_space: str, nth_batch_sample: int = 512, grid_size: int = 6,
17+
fileprefix: str = "SuperResolution",
1718
**kwargs):
1819
super(tf.keras.callbacks.Callback, self).__init__(**kwargs)
1920

@@ -37,15 +38,16 @@ def on_epoch_begin(self, epoch, logs=None):
3738
def on_epoch_end(self, epoch, logs=None):
3839
fig = show_expect_predicted_result(model=self.model, image_batch_dataset=self.trainSet,
3940
color_space=self.color_space, nr_col=self.grid_size)
40-
fig.savefig(os.path.join(self.dir_path, "{0}{1}.png".format(self.fileprefix,epoch)))
41+
fig.savefig(os.path.join(self.dir_path, "{0}{1}.png".format(self.fileprefix, epoch)))
4142
fig.clf()
4243
plt.close(fig)
4344

4445
def on_train_batch_end(self, batch, logs=None):
4546
if batch % self.nth_batch_sample == 0:
4647
fig = show_expect_predicted_result(model=self.model, image_batch_dataset=self.trainSet,
4748
color_space=self.color_space, nr_col=self.grid_size)
48-
fig.savefig(os.path.join(self.dir_path, "{0}_{1}_{2}.png".format(self.fileprefix, self.current_epoch, batch)))
49+
fig.savefig(
50+
os.path.join(self.dir_path, "{0}_{1}_{2}.png".format(self.fileprefix, self.current_epoch, batch)))
4951
fig.clf()
5052
plt.close(fig)
5153

@@ -154,7 +156,7 @@ def on_train_batch_end(self, batch, logs=None):
154156
def on_epoch_end(self, epoch, logs=None):
155157
super().on_epoch_end(epoch=epoch, logs=logs)
156158

157-
#TODO: add file output.
159+
# TODO: add file output.
158160

159161
# Plot detailed
160162
fig = plotTrainingHistory(self.batch_history, x_label="Batches", y_label="value")

0 commit comments

Comments
 (0)