Skip to content

Commit 3f11e6a

Browse files
Adding functionality to save the best model
1 parent 2c560bb commit 3f11e6a

3 files changed

Lines changed: 6 additions & 5 deletions

File tree

trainNN/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,10 @@ def train_bichrom(data_paths, outdir, seq_len, bin_size):
108108
probas_out_sc = outdir + '/msc/' + 'test_probs.txt'
109109
records_file_path = outdir + '/metrics'
110110
print(records_file_path)
111+
# save the best msc model
112+
msc.save(outdir + 'full_model.best.hdf5')
111113

112114
evaluate_models(sequence_len=seq_len, path=data_paths['test'],
113115
probas_out_seq=probas_out_seq, probas_out_sc=probas_out_sc,
114116
model_seq=mseq, model_sc=msc,
115117
records_file_path=records_file_path)
116-

trainNN/train_sc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from __future__ import division
21
import numpy as np
32
import pandas as pd
3+
import tensorflow as tf
44

55
from sklearn.metrics import average_precision_score as auprc
66
from tensorflow.keras.models import Model
77
from tensorflow.keras.layers import Dense, concatenate, Input, LSTM
88
from tensorflow.keras.layers import Conv1D, Reshape, Lambda
99
from tensorflow.keras.optimizers import SGD
10-
import tensorflow.keras.backend as K
1110
from tensorflow.keras.callbacks import Callback
1211
from tensorflow.keras.callbacks import ModelCheckpoint
12+
import tensorflow.keras.backend as K
1313

1414
from iterutils import train_generator
1515

@@ -135,7 +135,7 @@ def transfer(train_path, val_path, basemodel, model, steps_per_epoch,
135135
checkpointer = ModelCheckpoint(records_path + 'model_epoch{epoch}.hdf5',
136136
verbose=1, save_best_only=False)
137137

138-
hist = model.fit_generator(epochs=15, steps_per_epoch=steps_per_epoch,
138+
hist = model.fit_generator(epochs=1, steps_per_epoch=steps_per_epoch,
139139
generator=train_data_generator,
140140
validation_data=validation_data,
141141
callbacks=[precision_recall_history,

trainNN/train_seq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def train(model, train_path, val_path, steps_per_epoch, batch_size,
104104
# earlystop = EarlyStopping(monitor='val_loss', mode='min', verbose=1,
105105
# patience=5)
106106
# training the model..
107-
hist = model.fit_generator(epochs=15, steps_per_epoch=steps_per_epoch,
107+
hist = model.fit_generator(epochs=1, steps_per_epoch=steps_per_epoch,
108108
generator=train_generator,
109109
validation_data=validation_data,
110110
callbacks=[precision_recall_history,

0 commit comments

Comments
 (0)