-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_balanced.py
More file actions
120 lines (91 loc) · 3.83 KB
/
create_balanced.py
File metadata and controls
120 lines (91 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np # linear algebra
import os # accessing directory structure
import skimage
import imageio
from shutil import copyfile
from sklearn.model_selection import train_test_split
from PIL import Image
import cv2
import pickle
# GENERATING TRAINING DATA
def generateData(img_size, max_balanced, balanced_dir):
    """
    Load the balanced image set, resize it and split it into train/valid parts.

    For every category and every index in ``range(max_balanced)`` the file
    ``<balanced_dir>/<category>_<index>.png`` is read with OpenCV, resized to
    ``img_size`` x ``img_size`` and paired with a one-hot target vector.
    Paths that are missing, or that OpenCV cannot decode, are printed and
    skipped.

    @param: img_size: edge length every image is resized to
    @param: max_balanced: number of image indices probed per category
    @param: balanced_dir: directory holding the "<category>_<index>.png" files
    @return: trainImg, validImg, trainTarget, validTarget
             (the order in which train_test_split returns its pieces)
    @raise: ValueError when not a single image could be loaded
    """
    catNames = ["Black-grass", "Charlock", "Cleavers", "Common Chickweed", "Common wheat", "Fat Hen", "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds Purse", "Small-flowered Cranesbill", "Sugar beet"]

    allImg = []
    allTarget = []
    for catIndex, cat in enumerate(catNames):
        # one-hot target vector for this category
        tv = np.zeros(len(catNames), dtype=int)
        tv[catIndex] = 1

        for i in range(max_balanced):
            imgPath = os.path.join(balanced_dir, cat + "_" + str(i) + ".png")
            # cv2.imread reports an unreadable file by returning None instead
            # of raising, so treat that the same way as a missing file.
            im_frame = cv2.imread(imgPath) if os.path.isfile(imgPath) else None
            if im_frame is None:
                print(imgPath)
                continue

            # resizing the image to img_size, img_size. This is a basic
            # solution to the issue of varying image size
            res_im = cv2.resize(im_frame, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
            allImg.append(res_im)
            allTarget.append(tv)

    if not allImg:
        # Fail with a clear message rather than an obscure error further down.
        raise ValueError("no images could be loaded from " + str(balanced_dir))

    allImg = np.array(allImg)
    allTarget = np.array(allTarget)

    # train valid split; the one-hot rows themselves are handed over as the
    # stratification labels, exactly as before
    return train_test_split(allImg, allTarget, test_size=0.1, random_state=13, stratify=allTarget)
def main():
    """
    Build (if needed) a class-balanced copy of ./data and pickle a train/valid split.

    Steps:
      1. Count the files in every class sub-directory of ./data; the smallest
         count becomes the per-class quota ("max balanced").
      2. If ./balanced does not exist yet, copy that many files of every class
         into it, named "<class>_<n>.png".
      3. Call generateData on ./balanced with 128x128 images.
      4. Dump the four resulting arrays, plus the combined
         (trainX, trainY, validX, validY) tuple, into ./balanced_pickled.

    @raise: ValueError when ./data contains no class sub-directory
    """
    # find max balanced
    data_path = os.path.join(".", "data")
    # Only sub-directories are classes: a stray regular file in ./data would
    # make the nested os.listdir raise. Sorting makes the run reproducible,
    # because os.listdir returns entries in arbitrary order.
    dir_list = sorted(
        entry for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    )
    if not dir_list:
        raise ValueError("no class directories found in " + data_path)

    files_per_class = {}
    for dir_ in dir_list:
        class_path = os.path.join(data_path, dir_)
        files_per_class[dir_] = sorted(
            name for name in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, name))
        )
        print(dir_, len(files_per_class[dir_]), sep=': ')
    max_balanced = min(len(names) for names in files_per_class.values())
    print("---------")
    print("Max Balanced:", max_balanced)

    # create balanced dataset if it doesn't exist already
    balanced_dir = os.path.join(".", 'balanced')
    if not os.path.exists(balanced_dir):
        print("creating new balanced dataset with", max_balanced, "imgs per class")
        os.mkdir(balanced_dir)
        for dir_, names in files_per_class.items():
            # the first max_balanced files of each class, renumbered from 0
            for n, name in enumerate(names[:max_balanced]):
                src = os.path.join(data_path, dir_, name)
                dst = os.path.join(balanced_dir, dir_ + "_" + str(n) + ".png")
                copyfile(src, dst)
        print("finished creating new balanced dataset")
    else:
        print("balanced dataset already exists")

    # sanity check: actual number of files vs. the number expected
    print(len(os.listdir(balanced_dir)))
    print(len(dir_list) * max_balanced)

    img_size = 128
    trainX, validX, trainY, validY = generateData(img_size=img_size, max_balanced=max_balanced, balanced_dir=balanced_dir)

    # pickle dump: one file per array, plus one holding the whole tuple
    pickle_dir = os.path.join(".", 'balanced_pickled')
    os.makedirs(pickle_dir, exist_ok=True)
    outputs = [
        ("trainX", trainX),
        ("trainY", trainY),
        ("validX", validX),
        ("validY", validY),
        ("dataset", (trainX, trainY, validX, validY)),
    ]
    for prefix, payload in outputs:
        with open(os.path.join(pickle_dir, prefix + "_" + str(img_size)), "wb") as f:
            pickle.dump(payload, f)
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()