Skip to content

Commit c241753

Browse files
committed
Add tests for transform, bed loader, lightning datamodule
1 parent e55a3f6 commit c241753

2 files changed

Lines changed: 172 additions & 0 deletions

File tree

data/sample.bed

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
chr19 0 5 0 . +
2+
chr19 3 8 1 . +

tests/test_writer_loader.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import os
2+
import pandas as pd
3+
from seqchromloader import SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
4+
from seqchromloader import dump_data_webdataset
5+
6+
import unittest
7+
import tempfile
8+
import shutil
9+
import pathlib as pl
10+
import webdataset as wds
11+
12+
class Test(unittest.TestCase):
13+
def setUp(self) -> None:
14+
pass
15+
16+
def tearDown(self) -> None:
17+
pass
18+
19+
@classmethod
20+
def setUpClass(cls) -> None:
21+
cls.tempdir = tempfile.mkdtemp()
22+
23+
@classmethod
24+
def tearDownClass(cls) -> None:
25+
shutil.rmtree(cls.tempdir)
26+
27+
def assertIsFile(self, path):
28+
if not pl.Path(path).resolve().is_file():
29+
raise AssertionError("File does not exist: %s" % str(path))
30+
31+
def test_writer(self):
32+
coords = pd.DataFrame({
33+
'chrom': ["chr19", "chr19"],
34+
'start': [0, 3],
35+
'end': [5, 8],
36+
'label': [0, 1],
37+
'score': [".", "."],
38+
'strand': ["+", "+"]
39+
})
40+
huge_coords = pd.concat([coords] * 5000, axis=0).reset_index()
41+
dump_data_webdataset(huge_coords,
42+
genome_fasta='data/sample.fa',
43+
bigwig_filelist=['data/sample.bw'],
44+
target_bam='data/sample.bam',
45+
outdir=self.tempdir,
46+
outprefix='test',
47+
compress=True,
48+
numProcessors=5)
49+
self.assertIsFile(os.path.join(self.tempdir, "test_0.tar.gz"))
50+
51+
ds = wds.DataPipeline(
52+
wds.SimpleShardList([os.path.join(self.tempdir, "test_0.tar.gz")]),
53+
wds.tarfile_to_samples(),
54+
wds.decode(),
55+
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
56+
wds.batched(2)
57+
)
58+
seq, chrom, target, label = next(iter(ds))
59+
self.assertEqual(seq[1,0,4].item(), 1.0)
60+
self.assertEqual(chrom[0,0,3].item(), 999.0)
61+
self.assertEqual(target[0].item(), 2.0)
62+
self.assertEqual(label[1].item(), 1)
63+
64+
def test_wds_loader(self):
65+
it = iter(SeqChromDatasetByWds(["data/test_0.tar.gz"], dataloader_kws={"batch_size":3}))
66+
seq, chrom, target, label = next(it)
67+
68+
self.assertEqual(seq[0,0,3].item(), 1.0)
69+
self.assertEqual(chrom[0,0,3].item(), 4.0)
70+
self.assertEqual(target[0].item(), 0.0)
71+
self.assertEqual(label[1].item(), 1)
72+
73+
def test_wds_loader_transform(self):
74+
it = iter(SeqChromDatasetByWds(["data/test_0.tar.gz"],
75+
transforms={"seq": test_seq_transform,
76+
"chrom": test_chrom_transform,
77+
"target": test_target_transform},
78+
dataloader_kws={"batch_size":3}))
79+
seq, chrom, target, label = next(it)
80+
81+
self.assertEqual(seq[0,0,3].item(), 2.0)
82+
self.assertAlmostEqual(chrom[0,0,3].item(), 4.0/3)
83+
self.assertEqual(target[0].item(), 0.0)
84+
self.assertEqual(label[1].item(), 1)
85+
86+
def test_bed_loader(self):
87+
it = iter(SeqChromDatasetByBed(
88+
bed="data/sample.bed",
89+
genome_fasta="data/sample.fa",
90+
bigwig_filelist=["data/sample.bw"],
91+
target_bam="data/sample.bam",
92+
dataloader_kws={"batch_size":2,
93+
"shuffle":False}
94+
))
95+
seq, chrom, target, label = next(it)
96+
self.assertEqual(seq[1,0,4].item(), 1.0)
97+
self.assertEqual(chrom[0,0,3].item(), 999.0)
98+
self.assertEqual(target[0].item(), 2.0)
99+
self.assertEqual(label[1].item(), 1)
100+
101+
def test_bed_loader_transform(self):
102+
103+
it = iter(SeqChromDatasetByBed(
104+
bed="data/sample.bed",
105+
genome_fasta="data/sample.fa",
106+
bigwig_filelist=["data/sample.bw"],
107+
target_bam="data/sample.bam",
108+
transforms={"seq": test_seq_transform,
109+
"chrom": test_chrom_transform,
110+
"target": test_target_transform},
111+
dataloader_kws={"batch_size":2,
112+
"shuffle":False}
113+
))
114+
seq, chrom, target, label = next(it)
115+
self.assertEqual(seq[1,0,4].item(), 2.0)
116+
self.assertEqual(chrom[0,0,3].item(), 333.0)
117+
self.assertEqual(target[0].item(), 6.0)
118+
self.assertEqual(label[1].item(), 1)
119+
120+
def test_lightning_datamodule(self):
121+
dm = SeqChromDataModule(
122+
train_wds="data/test_0.tar.gz",
123+
val_wds="data/test_0.tar.gz",
124+
test_wds="data/test_0.tar.gz",
125+
train_dataset_size=100,
126+
batch_size=3,
127+
num_workers=1,
128+
patch_last=False,
129+
)
130+
dm.setup()
131+
val_dl = iter(dm.val_dataloader())
132+
seq, chrom, target, label = next(val_dl)
133+
self.assertEqual(seq[0,0,3].item(), 1.0)
134+
self.assertEqual(chrom[0,0,3].item(), 4.0)
135+
self.assertEqual(target[0].item(), 0.0)
136+
self.assertEqual(label[1].item(), 1)
137+
138+
def test_lightning_datamodule_transform(self):
139+
dm = SeqChromDataModule(
140+
train_wds="data/test_0.tar.gz",
141+
val_wds="data/test_0.tar.gz",
142+
test_wds="data/test_0.tar.gz",
143+
transforms={"seq": test_seq_transform,
144+
"chrom": test_chrom_transform,
145+
"target": test_target_transform},
146+
train_dataset_size=100,
147+
batch_size=3,
148+
num_workers=1,
149+
patch_last=False,
150+
)
151+
dm.setup()
152+
val_dl = iter(dm.val_dataloader())
153+
seq, chrom, target, label = next(val_dl)
154+
155+
self.assertEqual(seq[0,0,3].item(), 2.0)
156+
self.assertAlmostEqual(chrom[0,0,3].item(), 4.0/3)
157+
self.assertEqual(target[0].item(), 0.0)
158+
self.assertEqual(label[1].item(), 1)
159+
160+
def test_seq_transform(seq):
161+
return seq + 1
162+
163+
def test_chrom_transform(chrom):
164+
return chrom / 3
165+
166+
def test_target_transform(target):
167+
return target * 3
168+
169+
if __name__ == "__main__":
170+
unittest.main()

0 commit comments

Comments
 (0)