2323import theano .tensor as T
2424from theano .tensor .shared_randomstreams import RandomStreams
2525
26- from utils .load_conf import load_model ,load_conv_spec ,load_mlp_spec , load_data_spec
26+ from utils .load_conf import load_model ,load_conv_spec ,load_data_spec
2727from io_modules .file_reader import read_dataset
2828from utils .learn_rates import LearningRate
2929from utils .utils import parse_activation
@@ -43,8 +43,10 @@ def runCNN(arg):
4343 else :
4444 model_config = load_model (arg ,'CNN' )
4545
46- conv_config ,conv_layer_config ,mlp_config = load_conv_spec (model_config ['nnet_spec' ],model_config ['batch_size' ],
47- model_config ['input_shape' ])
46+ conv_config ,conv_layer_config ,mlp_config = load_conv_spec (
47+ model_config ['nnet_spec' ],
48+ model_config ['batch_size' ],
49+ model_config ['input_shape' ])
4850
4951 data_spec = load_data_spec (model_config ['data_spec' ],model_config ['batch_size' ]);
5052
@@ -59,43 +61,45 @@ def runCNN(arg):
5961 createDir (model_config ['wdir' ]);
6062 #create working dir
6163
62- #learning rate, batch-size and momentum
63- lrate = LearningRate .get_instance (model_config ['l_rate_method' ],model_config ['l_rate' ]);
6464 batch_size = model_config ['batch_size' ];
65- momentum = model_config ['momentum' ]
66-
6765 cnn = CNN (numpy_rng ,theano_rng ,conv_layer_configs = conv_layer_config , batch_size = batch_size ,
68- n_outs = model_config ['n_outs' ],hidden_layers_sizes = mlp_config ['layers' ], conv_activation = conv_activation ,
69- hidden_activation = hidden_activation ,use_fast = conv_config ['use_fast' ])
66+ n_outs = model_config ['n_outs' ],hidden_layers_sizes = mlp_config ['layers' ],
67+ conv_activation = conv_activation ,hidden_activation = hidden_activation ,
68+ use_fast = conv_config ['use_fast' ])
7069
71- train_sets , train_xy , train_x , train_y = read_dataset (data_spec ['training' ])
72- valid_sets , valid_xy , valid_x , valid_y = read_dataset (data_spec ['validation' ])
70+ if model_config ['processes' ]['finetuning' ]:
71+ #learning rate, batch-size and momentum
72+ lrate = LearningRate .get_instance (model_config ['l_rate_method' ],model_config ['l_rate' ]);
73+ momentum = model_config ['momentum' ]
7374
74- err = fineTunning (cnn ,train_sets ,train_xy ,train_x ,train_y ,
75- valid_sets ,valid_xy ,valid_x ,valid_y ,lrate ,momentum ,batch_size );
76-
77- _cnn2file (cnn .layers [0 :cnn .conv_layer_num ],cnn .layers [cnn .conv_layer_num :], filename = model_config ['output_file' ]);
75+ train_sets , train_xy , train_x , train_y = read_dataset (data_spec ['training' ])
76+ valid_sets , valid_xy , valid_x , valid_y = read_dataset (data_spec ['validation' ])
77+
78+ err = fineTunning (cnn ,train_sets ,train_xy ,train_x ,train_y ,
79+ valid_sets ,valid_xy ,valid_x ,valid_y ,lrate ,momentum ,batch_size );
7880
7981 ####################
8082 ## TESTING ##
8183 ####################
82- try :
83- test_sets , test_xy , test_x , test_y = read_dataset (data_spec ['testing' ])
84- except KeyError :
85- #raise e
86- logger .info ("No testing set:Skiping Testing" );
87- logger .info ("Finshed" )
88- sys .exit (0 )
89-
90- pred ,err = testing (cnn ,test_sets , test_xy , test_x , test_y ,batch_size )
91-
92- ####################
84+ if model_config ['processes' ]['testing' ]:
85+ try :
86+ test_sets , test_xy , test_x , test_y = read_dataset (data_spec ['testing' ])
87+ except KeyError :
88+ #raise e
89+ logger .info ("No testing set:Skiping Testing" );
90+ logger .info ("Finshed" )
91+ sys .exit (0 )
92+
93+ pred ,err = testing (cnn ,test_sets , test_xy , test_x , test_y ,batch_size )
94+
95+ ##########################
9396 ## Export Features ##
94- ####################
95- mlp_layers = cnn .layers [cnn .conv_layer_num :]
96- _file2cnn (cnn .conv_layers ,mlp_layers , filename = model_config ['output_file' ])
97+ ##########################
98+ if model_config ['processes' ]['export_data' ]:
99+ mlp_layers = cnn .layers [cnn .conv_layer_num :]
100+ _file2cnn (cnn .conv_layers ,mlp_layers , filename = model_config ['output_file' ])
101+ exportFeatures (cnn ,model_config ['export_path' ],data_spec ['testing' ])
97102
98- exportFeatures (cnn ,model_config ['export_path' ],data_spec ['testing' ])
99103
100104if __name__ == '__main__' :
101105 runCNN (sys .argv [1 ])
0 commit comments