-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexpression_main.py
More file actions
103 lines (95 loc) · 5.75 KB
/
expression_main.py
File metadata and controls
103 lines (95 loc) · 5.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from expression.expression_model import init_model_resnet50v2, init_model_alexnet, init_model_mobilenet, setup_network, init_model_train_expression, train_model, plot_result_train_model
from preparation.preparation import load_data_set
if __name__ == "__main__":
    # Build (and optionally train) the facial-expression classifiers. Every
    # backbone shares the data pipeline and hyper-parameters defined below.

    # ---- input / output locations ----------------------------------------
    img_height, img_width = 224, 224
    channels = 3
    train_data_path = "./dataset/expression/train/"
    validation_data_path = "./dataset/expression/validate/"
    # Directory for the per-model "best" checkpoint; a file name is appended
    # per model below.
    check_point_path = "./models/expression/working/"
    save_weight_vgg16 = "./models/expression/vgg16/vgg16_expression.h5"
    save_weight_vgg19 = "./models/expression/vgg19/vgg19_expression.h5"
    save_weight_resnet = "./models/expression/resnet/resnet_expression.h5"
    # NOTE: AlexNet is built with init_model_alexnet() (a different helper)
    # and is not part of this run; the path is kept for when it is re-enabled.
    save_weight_alexnet = "./models/expression/alexnet/alexnet_expression.h5"
    save_weight_mobilenet = "./models/expression/mobilenet/mobilenet_expression.h5"

    # ---- hyper-parameters -------------------------------------------------
    batch_size = 32
    epochs = 18
    include_top = False
    class_num = 8
    dropout = 0.2
    activation = "softmax"
    loss = "categorical_crossentropy"
    # True: fit every model, save its weights and plot its history.
    # False: only build the models. This is the default and matches what the
    # script did while its training calls were commented out.
    run_training = False

    # Random augmentation is applied to the training images only.
    train_datagen_args = dict(
        rotation_range=20,
        rescale=1./255,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
    )
    # Validation images get the same rescaling but no random transforms, so
    # validation metrics reflect the real images rather than augmented ones.
    # NOTE(review): assumes load_data_set forwards this dict to a Keras
    # ImageDataGenerator, as the keys suggest — confirm in preparation.
    validation_datagen_args = dict(rescale=1./255)

    # ---- data ---------------------------------------------------------------
    print("load data...")
    train_data = load_data_set(
        train_datagen_args, train_data_path, (img_height, img_width), batch_size)
    validation_data = load_data_set(
        validation_datagen_args, validation_data_path, (img_height, img_width), batch_size)
    # Whole batches per epoch. Clamp to at least one step so a split smaller
    # than one batch does not produce a step count of 0.
    step_size_train = max(1, train_data.n // train_data.batch_size)
    step_size_valid = max(1, validation_data.n // validation_data.batch_size)
    print("load data end...")

    # ---- models -------------------------------------------------------------
    # Each entry is (log label, backbone passed as `types`, layer_num,
    # final weights file). The layer_num values are carried over unchanged
    # from the original per-model calls.
    model_specs = [
        ("vgg16", "vgg16", 19, save_weight_vgg16),
        ("vgg19", "vgg19", 22, save_weight_vgg19),
        ("resnet", "resnet", 190, save_weight_resnet),
        ("mobilenetV2", "mobilenet", 154, save_weight_mobilenet),
    ]
    models = {}
    for label, model_type, layer_num, save_weight_path in model_specs:
        print(label + " model start...")
        model = init_model_train_expression(
            types=model_type, include_top=include_top,
            img_height=img_height, img_width=img_width, channels=channels,
            class_num=class_num, layer_num=layer_num, activation=activation,
            loss=loss, dropout=dropout)
        models[model_type] = model
        if not run_training:
            continue
        # A dedicated checkpoint file per model keeps one model's best
        # weights from overwriting another's in the shared working directory.
        _, history = train_model(
            checkpoint_path=check_point_path + model_type + "_best.h5",
            save_weight_path=save_weight_path, model=model,
            train_data=train_data, validation_data=validation_data,
            step_size_train=step_size_train, step_size_valid=step_size_valid,
            epochs_train=epochs)
        plot_result_train_model(history=history,
                                model_name=model_type + " accuracy")
        print(label + " model end...")