-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
195 lines (172 loc) · 7.66 KB
/
data_utils.py
File metadata and controls
195 lines (172 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Contains useful functions for the PyTorch model, class definition for the data
pipeline, loading and for generating the results.
"""
import glob
import random
from os.path import exists
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
# import matplotlib.pyplot as plt
class InputSequence(Dataset):
def __init__(self, path, image_shape, masks=False,
seq_length=2, step=1, aug=False, channels=1):
self.image_shape = image_shape
self.folder_path = path
self.masks = masks
self.seq_length = seq_length
self.aug = aug
self.channels = channels
self.step = step
self.folders = glob.glob("v*", root_dir=self.folder_path)
if masks:
self.sequences = self.generate_masks()
else:
self.sequences = self.generate_sequences()
self.dataset_len = len(self.sequences)
def __len__(self):
return self.dataset_len
def __getitem__(self, index):
p1, p2, p3, p4, p5 = self.sequences[index]
img5 = self.fetch_image(p5)
normal = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.seq_length == 2:
img4 = self.fetch_image(p4)
img4, img5 = self.transform(img4, img5)
if self.channels == 3:
input_images = torch.stack((img4, img4, img4), dim=0)
input_images = normal(input_images)
else:
input_images = img4[None] # add channel dimension
if self.seq_length == 3:
img4 = self.fetch_image(p4)
img3 = self.fetch_image(p3)
img3, img4, img5 = self.transform(img3, img4, img5)
if self.channels == 3: # Deal with pretrained RGB models
img3 = torch.stack((img3, img3, img3), dim=0)
img4 = torch.stack((img4, img4, img4), dim=0)
# Normalise for pretrained models
img3 = normal(img3)
img4 = normal(img4)
input_images = torch.cat((img3, img4), dim=0)
else:
# add channel dim and cat along channel
input_images = torch.cat((img3[None], img4[None]), dim=0)
if self.seq_length == 5:
img4 = self.fetch_image(p4)
img3 = self.fetch_image(p3)
img2 = self.fetch_image(p2)
img1 = self.fetch_image(p1)
img1, img2, img3, img4, img5 = self.transform(img1, img2, img3, img4, img5)
input_images = torch.cat((img1[None], img2[None], img3[None], img4[None]), dim=0)
if self.masks:
img5 = torch.stack((img5, 1 - img5), dim=0)
else:
img5 = img5[None] # Add a channel dimension
return (input_images, img5)
def fetch_image(self, path):
# return cv2.imread(self.folder_path + "/" + path, cv2.IMREAD_GRAYSCALE)
return PIL.Image.open(self.folder_path + "/" + path)
def check_seq(self, *args, step=1):
# ensure all images are same timestep apart
for img1, img2 in zip(args, args[1:]):
if (int(img2[3:-4]) - int(img1[3:-4]) != step):
return False
return True
def generate_sequences(self):
# Generate sequence of 5 images with same timestep,
# if only using a seq of length 3, we simply ignore images 1 & 2 later.
# Avoids extra complications in ensuring the sequences tested on are
# the same
sequences = {}
counter = 0
for folder in self.folders:
files = glob.glob(f"{folder}/*.png", root_dir=self.folder_path)
for n in range(self.step):
f_temp = sorted(files)[n::self.step] # take every Nth (N=step) element
for (img1, img2, img3, img4, img5) in \
zip(f_temp, f_temp[1:], f_temp[2:], f_temp[3:], f_temp[4:]):
if self.check_seq(img1, img2, img3, img4, img5, step=self.step):
sequences[counter] = (img1, img2, img3, img4, img5)
counter += 1
return sequences
def generate_masks(self):
sequences = {}
counter = 0
for folder in self.folders:
files = glob.glob(f"{folder}/*.png", root_dir=self.folder_path)
for n in range(self.step):
f_temp = sorted(files)[n::self.step] # take every Nth (N=step) element
for (img1, img2, img3, img4, img5) in \
zip(f_temp, f_temp[1:], f_temp[2:], f_temp[3:], f_temp[4:]):
if self.check_seq(img1, img2, img3, img4, img5, step=self.step):
img5 = "/masks/" + img4[3:] # get mask of img4
if exists(self.folder_path + img5):
sequences[counter] = (img1, img2, img3, img4, img5)
counter += 1
return sequences
def transform(self, *args):
images = []
norm = T.Normalize(mean=0.505, std=0.145)
if not self.aug:
resize = T.Resize(size=self.image_shape, antialias=False, interpolation=T.InterpolationMode.NEAREST)
for image in args:
temp = TF.to_tensor(image)
#temp = norm(temp)
images.append(resize(temp)[0])
return images
hflip = random.random()
vflip = random.random()
i, j, h, w = T.RandomResizedCrop(size=self.image_shape, antialias=False).get_params(
TF.to_tensor(args[0]), scale=[0.5, 1.0], ratio=[0.75, 1.25])
d = T.RandomRotation.get_params(degrees=[-30, 30])
for image in args:
# Transform to tensor
image = TF.to_tensor(image)
#image = norm(image)
# Random crop
image = TF.crop(image, i, j, h, w)
# Random horizontal flip
if hflip > 0.5:
image = TF.hflip(image)
# Random vertical flip
if vflip > 0.5:
image = TF.vflip(image)
# Random Rotation
image = TF.rotate(image, angle=d)
# Random Jitter
#image = rand_jitter.forward(image)
# Resize
resize = T.Resize(size=self.image_shape, antialias=False, interpolation=T.InterpolationMode.NEAREST)
image = resize(image)
images.append(image[0]) # tensors have been formated to [1, x, y]
return images
def load_data(path, image_shape, batch_size=10, shuffle=True,
              masks=False, seq_length=3, step=1, aug=False, channels=1,
              num_workers=18):
    """Build a ``DataLoader`` over an :class:`InputSequence` dataset.

    Dataset parameters (``path`` .. ``channels``) are forwarded to
    ``InputSequence``; ``batch_size``/``shuffle`` go to the loader.

    ``num_workers`` (default 18, matching previous hard-coded behaviour)
    lets callers tune the worker count to their machine; it must be > 0
    because ``persistent_workers=True`` requires at least one worker.
    """
    dataset = InputSequence(path, image_shape, masks, seq_length, step,
                            aug, channels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, persistent_workers=True,
                      pin_memory=True)
# def show_samples(dataloader, num_samples=5):
# fig, ax = plt.subplots(num_samples,
# 3,
# gridspec_kw={'wspace': 0, 'hspace': 0},
# subplot_kw={'xticks': [], 'yticks': []})
#
# for i, (samples, truth) in enumerate(dataloader):
# # enumerate delivers a batch, just pick the first in the batch
# ax[i, 0].imshow(samples[0][0].numpy()) # first channel
# ax[i, 1].imshow(samples[0][1].numpy()) # second channel
# ax[i, 2].imshow(truth[0][0].numpy())
#
# if i == (num_samples - 1):
# break
# fig.suptitle("Sample images from dataset")
# # fig.supxlabel("1st, 2nd and 3rd Image from Sequence")
# # fig.supylabel("Samples")
# plt.show()