Skip to content

Commit 35b4107

Browse files
committed
Added: checkpointing of model
1 parent d53e9f2 commit 35b4107

3 files changed

Lines changed: 23 additions & 3 deletions

File tree

sample_config/MNIST/CNN/model_conf.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"input_shape" : [1,28,28],
99
"batch_size" : 256,
1010
"n_outs" : 10,
11+
"save_feq":4,
1112
"finetune_params" : {
1213
"method" : "E",
1314
"momentum" : 0.5,

src/pythonDnn/run/__init__.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time,numpy,os
2+
import threading
23
import logging
34
logger = logging.getLogger(__name__)
45

@@ -44,7 +45,7 @@ def testing(nnetModel,data_spec,saveLabel=True,outFile='test.out'):
4445
saveLabels(nnetModel,outFile,data_spec['testing'])
4546

4647

47-
def _fineTunning(nnetModel,train_sets,valid_sets,lrate,momentum):
48+
def _fineTunning(nnetModel,train_sets,valid_sets,lrate,momentum,saveFeq=0,outFile=''):
4849

4950
train_xy = train_sets.shared_xy
5051
train_x = train_sets.shared_x
@@ -85,6 +86,17 @@ def valid_score():
8586
logger.debug('Training batch %d error %f',batch_index, numpy.mean(train_error))
8687
train_sets.read_next_partition_data()
8788
logger.info('Fine Tunning:epoch %d, training error %f',lrate.epoch, numpy.mean(train_error));
89+
90+
#savemodel
91+
savethread=None
92+
93+
if(saveFeq!=0 and lrate.epoch%saveFeq==0):
94+
logger.info('Saving partial model to:%s',outFile)
95+
savethread=threading.Thread(target=nnetModel.save,kwargs={'filename':outFile});
96+
savethread.start()
97+
#nnetModel.save(filename=outFile);
98+
99+
88100
train_sets.initialize_read()
89101

90102
valid_error = valid_score()
@@ -93,6 +105,9 @@ def valid_score():
93105
logger.info('Fine Tunning:epoch %d, validation error %f',lrate.epoch, valid_error);
94106
lrate.get_next_rate(current_error = 100 * valid_error)
95107

108+
if savethread!=None:
109+
savethread.join();
110+
96111
end_time = time.clock()
97112

98113
logger.info('Best validation error %f',best_validation_loss)
@@ -111,6 +126,8 @@ def fineTunning(nnetModel,model_config,data_spec):
111126
logger.info("No validation/training set:Skiping Fine tunning");
112127
else:
113128
try:
129+
outFile=model_config['output_file']
130+
saveFeq=model_config['save_feq']
114131
finetune_config = model_config['finetune_params']
115132
momentum = finetune_config['momentum']
116133
lrate = LearningRate.get_instance(finetune_config);
@@ -119,8 +136,7 @@ def fineTunning(nnetModel,model_config,data_spec):
119136
logger.critical("Fine tunning Paramters Missing")
120137
exit(2)
121138

122-
123-
_fineTunning(nnetModel,train_sets,valid_sets,lrate,momentum)
139+
_fineTunning(nnetModel,train_sets,valid_sets,lrate,momentum,saveFeq,outFile)
124140

125141

126142
def exportFeatures(nnetModel,model_config,data_spec):

src/pythonDnn/utils/load_conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def load_model(input_file,nnetType=None):
3939
if not data.has_key('random_seed') or not type(data['random_seed']) is int:
4040
data['random_seed'] = None
4141

42+
if not data.has_key('save_feq') or not type(data['save_feq']) is int:
43+
data['save_feq'] = 0
44+
4245
if data.has_key('n_ins') or data.has_key('input_shape'):
4346
pass
4447
else:

0 commit comments

Comments
 (0)