Skip to content

Commit d4bb877

Browse files
committed
add test for batch_size
1 parent bbe3e7f commit d4bb877

1 file changed

Lines changed: 30 additions & 0 deletions

File tree

tests/test_writer_loader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,36 @@ def test_chop_genome(self):
8383
self.assertTrue(BedTool().from_dataframe(coords_incl).intersect(interval).count()==len(coords_incl))
8484
self.assertTrue(BedTool().from_dataframe(coords_excl).intersect(interval).count()==0)
8585

86+
def test_write_load_batched_target_bam(self):
87+
coords = pd.DataFrame({
88+
'chrom': ["chr19", "chr19"],
89+
'start': [0, 3],
90+
'end': [5, 8],
91+
'label': [0, 1],
92+
'score': [".", "."],
93+
'strand': ["+", "+"]
94+
})
95+
huge_coords = pd.concat([coords] * 6000, axis=0).reset_index()
96+
dump_data_webdataset(huge_coords,
97+
genome_fasta='data/sample.fa',
98+
bigwig_filelist=['data/sample.bw'],
99+
target_bam='data/sample.bam',
100+
outdir=self.tempdir,
101+
outprefix='test',
102+
compress=True,
103+
numProcessors=2,
104+
batch_size=128)
105+
self.assertIsFile(os.path.join(self.tempdir, "test_0.tar.gz"))
106+
ds = wds.DataPipeline(
107+
wds.SimpleShardList([os.path.join(self.tempdir, "test_0.tar.gz")]),
108+
wds.tarfile_to_samples(),
109+
wds.decode(),
110+
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
111+
wds.batched(2)
112+
)
113+
seq, chrom, target, label = next(iter(ds))
114+
self.assertEqual(seq.shape, (2, 128, 4, 5))
115+
86116
def test_write_load_target_bam(self):
87117
coords = pd.DataFrame({
88118
'chrom': ["chr19", "chr19"],

0 commit comments

Comments
 (0)