Skip to content

Commit 8246e27

Browse files
committed
Add drop_remainder flag to avoid dropping data when testing
1 parent b00919f commit 8246e27

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

trainNN/iterutils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def train_generator_h5(h5file, dspath, batchsize, seqlen, dtype, iterflag):
119119
else:
120120
yield ds[start_index:end_index]
121121

122-
def train_TFRecord_dataset(dspath, batchsize, dataflag, shuffle=True):
122+
def train_TFRecord_dataset(dspath, batchsize, dataflag, shuffle=True, drop_remainder=True):
123123

124124
#raw_dataset = tf.data.TFRecordDataset(dspath["TFRecord"])
125125

@@ -158,7 +158,7 @@ def _parse_function_wrapper(example_proto):
158158
parsed_dataset = files.interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
159159
if shuffle: parsed_dataset = parsed_dataset.shuffle(100)
160160
parsed_dataset = (parsed_dataset.map(_parse_function_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
161-
.batch(batchsize, drop_remainder=True)
161+
.batch(batchsize, drop_remainder=drop_remainder)
162162
.prefetch(tf.data.AUTOTUNE))
163163

164164
return parsed_dataset

0 commit comments

Comments
 (0)